File size: 4,626 Bytes
a366dd4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | """
Module containing utilities for NDFrame.sample() and .GroupBy.sample()
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from pandas._libs import lib
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)
if TYPE_CHECKING:
from pandas._typing import AxisInt
from pandas.core.generic import NDFrame
def preprocess_weights(obj: NDFrame, weights, axis: AxisInt) -> np.ndarray:
"""
Process and validate the `weights` argument to `NDFrame.sample` and
`.GroupBy.sample`.
Returns `weights` as an ndarray[np.float64], validated except for normalizing
weights (because that must be done groupwise in groupby sampling).
"""
# If a series, align with frame
if isinstance(weights, ABCSeries):
weights = weights.reindex(obj.axes[axis])
# Strings acceptable if a dataframe and axis = 0
if isinstance(weights, str):
if isinstance(obj, ABCDataFrame):
if axis == 0:
try:
weights = obj[weights]
except KeyError as err:
raise KeyError(
"String passed to weights not a valid column"
) from err
else:
raise ValueError(
"Strings can only be passed to "
"weights when sampling from rows on "
"a DataFrame"
)
else:
raise ValueError(
"Strings cannot be passed as weights when sampling from a Series."
)
if isinstance(obj, ABCSeries):
func = obj._constructor
else:
func = obj._constructor_sliced
weights = func(weights, dtype="float64")._values
if len(weights) != obj.shape[axis]:
raise ValueError("Weights and axis to be sampled must be of same length")
if lib.has_infs(weights):
raise ValueError("weight vector may not include `inf` values")
if (weights < 0).any():
raise ValueError("weight vector many not include negative values")
missing = np.isnan(weights)
if missing.any():
# Don't modify weights in place
weights = weights.copy()
weights[missing] = 0
return weights
def process_sampling_size(
n: int | None, frac: float | None, replace: bool
) -> int | None:
"""
Process and validate the `n` and `frac` arguments to `NDFrame.sample` and
`.GroupBy.sample`.
Returns None if `frac` should be used (variable sampling sizes), otherwise returns
the constant sampling size.
"""
# If no frac or n, default to n=1.
if n is None and frac is None:
n = 1
elif n is not None and frac is not None:
raise ValueError("Please enter a value for `frac` OR `n`, not both")
elif n is not None:
if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide `n` >= 0."
)
if n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
else:
assert frac is not None # for mypy
if frac > 1 and not replace:
raise ValueError(
"Replace has to be set to `True` when "
"upsampling the population `frac` > 1."
)
if frac < 0:
raise ValueError(
"A negative number of rows requested. Please provide `frac` >= 0."
)
return n
def sample(
obj_len: int,
size: int,
replace: bool,
weights: np.ndarray | None,
random_state: np.random.RandomState | np.random.Generator,
) -> np.ndarray:
"""
Randomly sample `size` indices in `np.arange(obj_len)`
Parameters
----------
obj_len : int
The length of the indices being considered
size : int
The number of values to choose
replace : bool
Allow or disallow sampling of the same row more than once.
weights : np.ndarray[np.float64] or None
If None, equal probability weighting, otherwise weights according
to the vector normalized
random_state: np.random.RandomState or np.random.Generator
State used for the random sampling
Returns
-------
np.ndarray[np.intp]
"""
if weights is not None:
weight_sum = weights.sum()
if weight_sum != 0:
weights = weights / weight_sum
else:
raise ValueError("Invalid weights: weights sum to zero")
return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
np.intp, copy=False
)
|