Spaces:
Sleeping
Sleeping
File size: 5,969 Bytes
3ca4c9f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer # <-- REQUIRED IMPORT
from sklearn.preprocessing import OneHotEncoder # <-- REQUIRED IMPORT
import pandas as pd
import numpy as np # <-- REQUIRED IMPORT
from typing import Optional, Iterable, Any # <-- REQUIRED IMPORTS for type hinting
# Define the custom transformer class
class ManualProductTypeMapper(BaseEstimator, TransformerMixin):
    """Collapse a product-type column onto a controlled set of categories.

    Every value of ``product_col`` that is not a member of ``keep_set`` is
    replaced by the literal string ``'Others'``.

    Parameters
    ----------
    product_col : str, default 'Product_Type'
        Name of the DataFrame column to remap.
    keep_set : iterable of str, optional
        Values that are left untouched. It is validated in ``fit`` and must
        be a non-empty iterable that is not a bare string.

    Attributes
    ----------
    keep_set_ : set
        Copy of ``keep_set`` created by ``fit`` and used for membership tests.
    """

    def __init__(self, product_col: str = 'Product_Type', keep_set: Optional[Iterable[str]] = None):
        # Constructor arguments are stored untouched; all validation is
        # deferred to fit().
        self.product_col = product_col
        self.keep_set = keep_set

    def fit(self, X: pd.DataFrame, y: Optional[Any] = None):
        """Validate the inputs and build the fitted ``keep_set_`` attribute.

        Parameters
        ----------
        X : pandas.DataFrame
            Data that must contain ``product_col``.
        y : ignored
            Accepted for API compatibility only.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not a DataFrame, ``product_col`` is missing from
            ``X``, or ``keep_set`` is ``None``, a bare string, or empty.
        """
        if not isinstance(X, pd.DataFrame):
            raise ValueError("fit expects X to be a pandas DataFrame")
        if self.product_col not in X.columns:
            raise ValueError(f"product_col '{self.product_col}' not found in X during fit")
        if self.keep_set is None:
            raise ValueError("ManualProductTypeMapper requires a non-empty keep_set (pass an iterable of values)")
        # set("Dairy") would silently become {'D', 'a', 'i', 'r', 'y'}, so a
        # single string is rejected instead of being split into characters.
        if isinstance(self.keep_set, (str, bytes)):
            raise ValueError("keep_set must be an iterable of values, not a single string")

        # Defensive copy; a set gives constant-time membership tests.
        keep = set(self.keep_set)
        if not keep:
            # Enforce the "non-empty" requirement the error message states.
            raise ValueError("ManualProductTypeMapper requires a non-empty keep_set (pass an iterable of values)")
        self.keep_set_ = keep
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``X`` with non-kept product types set to 'Others'.

        Parameters
        ----------
        X : pandas.DataFrame
            Data that must contain ``product_col``.

        Returns
        -------
        pandas.DataFrame
            Copy of ``X``; the caller's frame is not modified.

        Raises
        ------
        ValueError
            If called before ``fit``, if ``X`` is not a DataFrame, or if
            ``product_col`` is missing from ``X``.
        """
        if not hasattr(self, 'keep_set_'):
            raise ValueError("transform called before fit(). Call fit(X) first.")
        if not isinstance(X, pd.DataFrame):
            raise ValueError("transform expects a pandas DataFrame")
        if self.product_col not in X.columns:
            raise ValueError(f"product_col '{self.product_col}' not found in X during transform")

        # Work on a copy so the caller's DataFrame is never mutated.
        mapped = X.copy()
        keep = self.keep_set_
        mapped[self.product_col] = mapped[self.product_col].apply(
            lambda value: value if value in keep else 'Others'
        )
        return mapped

    def fit_transform(self, X: pd.DataFrame, y: Optional[Any] = None, **fit_params) -> pd.DataFrame:
        """Fit on ``X`` and return the mapped frame, guaranteeing an 'Others' row.

        If no row maps to 'Others', one synthetic row with
        ``product_col == 'Others'`` is appended and the index is reset, so
        that the category is present in the training output.

        NOTE(review): in that case the result has ``len(X) + 1`` rows while
        ``y`` is left unchanged, so any consumer that pairs the output with
        ``y`` will see a length mismatch -- confirm callers account for this.

        Parameters
        ----------
        X : pandas.DataFrame
            Training data containing ``product_col``.
        y : ignored
            Forwarded to ``fit``, which does not use it.
        **fit_params
            Accepted for API compatibility; not used.

        Returns
        -------
        pandas.DataFrame
            The mapped data, possibly followed by one synthetic row.
        """
        self.fit(X, y)
        X_trans = self.transform(X)

        # Nothing to add when the mapping already produced an 'Others' row.
        if X_trans[self.product_col].eq('Others').any():
            return X_trans

        synthetic_df = pd.DataFrame(
            [self._synthetic_others_row(X_trans)], columns=X_trans.columns
        )
        return pd.concat([X_trans, synthetic_df], ignore_index=True)

    def _synthetic_others_row(self, frame: pd.DataFrame) -> dict:
        """Build one placeholder row whose product type is 'Others'.

        Other columns get a neutral fill computed from their non-null values:
        the mode for object/categorical/string columns, the median (as a
        float) for numeric columns, the first non-null value for any other
        dtype, and NaN when the column has no non-null values.
        """
        types = pd.api.types
        row: dict = {}
        for col in frame.columns:
            if col == self.product_col:
                row[col] = 'Others'
                continue

            ser = frame[col].dropna()
            if ser.empty:
                row[col] = np.nan
            elif (
                types.is_object_dtype(ser)
                # isinstance check replaces the deprecated is_categorical_dtype.
                or isinstance(ser.dtype, pd.CategoricalDtype)
                or types.is_string_dtype(ser)
            ):
                row[col] = ser.mode().iloc[0]
            elif types.is_numeric_dtype(ser):
                row[col] = float(ser.median())
            else:
                # e.g. datetime-like columns: reuse an existing value.
                row[col] = ser.iloc[0]
        return row
# ------------------ Hard-coded keep list (edit as needed) ------------------
# Product types that keep their own category; meant to be supplied as the
# `keep_set` argument of ManualProductTypeMapper.
KEEP_PRODUCT_TYPES = {
    'Baking Goods',
    'Canned',
    'Dairy',
    'Frozen Foods',
    'Fruits and Vegetables',
    'Health and Hygiene',
    'Household',
    'Meat',
    'Snack Foods',
    'Soft Drinks',
}
# ------------------ Example of Use (NOT part of the final pipeline object itself) ------------------
# NOTE: The variable 'cat_cols' is not defined in this file; it must be defined
# or imported before it is used to build the ColumnTransformer.
# Example usage (commented out because 'cat_cols' is undefined at module scope):
# cat_cols = ['Store_Type', 'Store_Location_Type', 'Store_Size', 'Product_Type']
#
# # Step 1: Custom transformer that groups rare Product_Type values into 'Others'
# mapper = ManualProductTypeMapper(
# product_col='Product_Type',
# keep_set=KEEP_PRODUCT_TYPES # your manually defined keep list
# )
#
# # Step 2: Define how categorical columns should be encoded
# col_transformer = ColumnTransformer(
# transformers=[
# # The mapper must run *before* the OneHotEncoder, e.g. as an earlier step of the enclosing Pipeline.
# # Here, we assume the mapper has already run *before* this ColumnTransformer in the main pipeline.
# ('ohe_cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first'), cat_cols),
# ],
# remainder='passthrough'
# )
|