File size: 5,969 Bytes
3ca4c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136


from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer # <-- REQUIRED IMPORT
from sklearn.preprocessing import OneHotEncoder # <-- REQUIRED IMPORT
import pandas as pd
import numpy as np # <-- REQUIRED IMPORT
from typing import Optional, Iterable, Any # <-- REQUIRED IMPORTS for type hinting

# Define the custom transformer class
class ManualProductTypeMapper(BaseEstimator, TransformerMixin):
    """
    Transformer that maps values of a Product-Type column to a controlled set of
    allowed categories, mapping all other (unwanted / rare / unknown) values to 'Others'.

    Parameters
    ----------
    product_col : str, default='Product_Type'
        Name of the column to remap. It must be present in X for both fit and transform.
    keep_set : iterable of str
        Values that are kept unchanged; every other value becomes 'Others'.
        Required. Must be a non-empty iterable of values. A bare string is rejected,
        because ``set('Dairy')`` would split it into single characters.
    add_others_row : bool, default=True
        Only affects ``fit_transform``. When True, ``y`` is None, and the mapped
        training data contains no 'Others' value, one synthetic 'Others' row is
        appended to the returned frame (see ``fit_transform``).

    Attributes
    ----------
    keep_set_ : set
        Defensive copy of ``keep_set`` built by ``fit``.
    """

    def __init__(self, product_col: str = 'Product_Type', keep_set: Optional[Iterable[str]] = None,
                 add_others_row: bool = True):
        # Store constructor arguments exactly as provided; validation happens in fit().
        self.product_col = product_col
        self.keep_set = keep_set
        self.add_others_row = add_others_row

    def fit(self, X: pd.DataFrame, y: Optional[Any] = None) -> "ManualProductTypeMapper":
        """
        Validate inputs and build the internal ``keep_set_``.

        Parameters
        ----------
        X : pandas.DataFrame
            Training data; must contain ``product_col``.
        y : ignored
            Accepted for API compatibility.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If X is not a DataFrame, ``product_col`` is missing from X, or
            ``keep_set`` is None, a bare string, or empty.
        """
        if not isinstance(X, pd.DataFrame):
            raise ValueError("fit expects X to be a pandas DataFrame")
        if self.product_col not in X.columns:
            raise ValueError(f"product_col '{self.product_col}' not found in X during fit")

        if self.keep_set is None:
            raise ValueError("ManualProductTypeMapper requires a non-empty keep_set (pass an iterable of values)")
        # A str is iterable, so set() would silently turn 'Dairy' into {'D', 'a', 'i', 'r', 'y'}.
        if isinstance(self.keep_set, str):
            raise ValueError("keep_set must be an iterable of values, not a single string; "
                             "wrap it in a list or set")

        # Defensive copy: later mutation of the caller's iterable does not affect the fitted state.
        keep = set(self.keep_set)
        if not keep:
            raise ValueError("ManualProductTypeMapper requires a non-empty keep_set (pass an iterable of values)")
        self.keep_set_ = keep

        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of X with every ``product_col`` value not in ``keep_set_`` replaced by 'Others'.

        Raises
        ------
        ValueError
            If called before ``fit``, if X is not a DataFrame, or if ``product_col``
            is missing from X.
        """
        if not hasattr(self, 'keep_set_'):
            raise ValueError("transform called before fit(). Call fit(X) first.")

        if not isinstance(X, pd.DataFrame):
            raise ValueError("transform expects a pandas DataFrame")
        if self.product_col not in X.columns:
            raise ValueError(f"product_col '{self.product_col}' not found in X during transform")

        # DataFrame.copy() is deep by default, so the caller's DataFrame is never mutated.
        X2 = X.copy()
        keep = self.keep_set_
        X2[self.product_col] = X2[self.product_col].apply(lambda v: v if v in keep else 'Others')
        return X2

    def fit_transform(self, X: pd.DataFrame, y: Optional[Any] = None, **fit_params) -> pd.DataFrame:
        """
        Fit the transformer and transform X in one step.

        When ``add_others_row`` is True, ``y`` is None, and the mapped data has no
        'Others' value, one synthetic row is appended. In that row ``product_col``
        is 'Others' and every other column gets a placeholder from
        ``_placeholder_value``. The index is then reset. This lets an encoder
        fitted on this output see an 'Others' category. In that case the result
        has one more row than ``fit(X).transform(X)``.

        When ``y`` is supplied, no row is ever appended. The output must keep
        len(X) rows to stay aligned with y, for example inside
        ``Pipeline.fit(X, y)``.

        Extra ``fit_params`` are accepted and ignored.
        """
        self.fit(X, y)
        X_trans = self.transform(X)

        # Appending a row when y is present would hand downstream steps
        # len(X) + 1 samples but only len(y) targets.
        if not self.add_others_row or y is not None:
            return X_trans

        if 'Others' in X_trans[self.product_col].unique():
            return X_trans

        synthetic_df = pd.DataFrame([self._synthetic_others_row(X_trans)], columns=X_trans.columns)
        return pd.concat([X_trans, synthetic_df], ignore_index=True)

    def _synthetic_others_row(self, X_trans: pd.DataFrame) -> dict:
        """Build one row dict: 'Others' for product_col, a per-column placeholder elsewhere."""
        return {
            col: 'Others' if col == self.product_col else self._placeholder_value(X_trans[col])
            for col in X_trans.columns
        }

    @staticmethod
    def _placeholder_value(ser: pd.Series) -> Any:
        """
        Pick a fill value for one column of the synthetic row.

        Rules, in order:

        - NaN if the column has no non-null values.
        - Mode for object, categorical, or string-like columns.
        - Median, as a Python float, for numeric columns.
        - Otherwise the first non-null value.
        """
        non_null = ser.dropna()
        if non_null.empty:
            return np.nan
        # isinstance check replaces pd.api.types.is_categorical_dtype (deprecated in pandas 2.1).
        if (pd.api.types.is_object_dtype(non_null)
                or isinstance(non_null.dtype, pd.CategoricalDtype)
                or pd.api.types.is_string_dtype(non_null)):
            return non_null.mode().iloc[0]
        if pd.api.types.is_numeric_dtype(non_null):
            return float(non_null.median())
        return non_null.iloc[0]


# ------------------ Hard-coded keep list (edit as needed) ------------------
# Define the KEEP_PRODUCT_TYPES set.
# Product types to keep verbatim when this set is passed as keep_set to ManualProductTypeMapper.
KEEP_PRODUCT_TYPES = {
    'Baking Goods',
    'Canned',
    'Dairy',
    'Frozen Foods',
    'Fruits and Vegetables',
    'Health and Hygiene',
    'Household',
    'Meat',
    'Snack Foods',
    'Soft Drinks',
}

# ------------------ Example of Use (NOT part of the final pipeline object itself) ------------------
# NOTE: `cat_cols` is not defined in this file; it must be defined or imported
# before it is used to build the ColumnTransformer.

# Example usage (commented out as these variables are undefined in this file scope):
# cat_cols = ['Store_Type', 'Store_Location_Type', 'Store_Size', 'Product_Type']
# 
# # Step 1: Custom transformer that groups rare Product_Type values into 'Others'
# mapper = ManualProductTypeMapper(
#     product_col='Product_Type',
#     keep_set=KEEP_PRODUCT_TYPES   # your manually defined keep list
# )
# 
# # Step 2: Define how categorical columns should be encoded
# col_transformer = ColumnTransformer(
#     transformers=[
#         # The mapper is deliberately not listed here: it is assumed to run as an earlier step of the
#         # enclosing Pipeline, so 'Product_Type' has already been remapped when this encoder sees it.
#         ('ohe_cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first'), cat_cols),
#     ],
#     remainder='passthrough'
# )