Varun6299 commited on
Commit
01ad5cf
·
verified ·
1 Parent(s): 4c38460

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_transformers.py +112 -0
custom_transformers.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from sklearn.base import BaseEstimator, TransformerMixin
4
+ from sklearn.compose import ColumnTransformer # <-- REQUIRED IMPORT
5
+ from sklearn.preprocessing import OneHotEncoder # <-- REQUIRED IMPORT
6
+ import pandas as pd
7
+ import numpy as np # <-- REQUIRED IMPORT
8
+ from typing import Optional, Iterable, Any # <-- REQUIRED IMPORTS for type hinting
9
+
10
+ # Define the custom transformer class
11
+ class ManualProductTypeMapper(BaseEstimator, TransformerMixin):
12
+ """
13
+ Transformer that maps values of a Product-Type column to a controlled set of
14
+ allowed categories, mapping all other (unwanted / rare / unknown) values to 'Others'.
15
+ """
16
+
17
+ def __init__(self, product_col: str = 'Product_Type', keep_set: Optional[Iterable[str]] = None):
18
+ # Store constructor arguments exactly as provided.
19
+ self.product_col = product_col
20
+ self.keep_set = keep_set
21
+
22
+
23
+ def fit(self, X: pd.DataFrame, y: Optional[Any] = None):
24
+ """
25
+ Validate inputs and prepare internal state.
26
+ """
27
+ # Basic input validation
28
+ if not isinstance(X, pd.DataFrame):
29
+ raise ValueError("fit expects X to be a pandas DataFrame")
30
+ if self.product_col not in X.columns:
31
+ raise ValueError(f"product_col '{self.product_col}' not found in X during fit")
32
+
33
+ # keep_set must be provided by user; convert into an internal set for fast membership tests
34
+ if self.keep_set is None:
35
+ raise ValueError("ManualProductTypeMapper requires a non-empty keep_set (pass an iterable of values)")
36
+
37
+ # Create a defensive copy and ensure type is set
38
+ self.keep_set_ = set(self.keep_set)
39
+
40
+ return self
41
+
42
+ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
43
+ """
44
+ Map values not in keep_set_ to 'Others'.
45
+ """
46
+ # Ensure fit has been called
47
+ if not hasattr(self, 'keep_set_'):
48
+ raise ValueError("transform called before fit(). Call fit(X) first.")
49
+
50
+ if not isinstance(X, pd.DataFrame):
51
+ raise ValueError("transform expects a pandas DataFrame")
52
+ if self.product_col not in X.columns:
53
+ raise ValueError(f"product_col '{self.product_col}' not found in X during transform")
54
+
55
+ # Work on a shallow copy to avoid mutating the user's DataFrame
56
+ X2 = X.copy()
57
+
58
+ # Define the function for mapping to 'Others'
59
+ def mapper_func(v):
60
+ return v if v in self.keep_set_ else 'Others'
61
+
62
+ # Apply the mapping
63
+ X2[self.product_col] = X2[self.product_col].apply(mapper_func)
64
+ return X2
65
+
66
+ def fit_transform(self, X: pd.DataFrame, y: Optional[Any] = None, **fit_params) -> pd.DataFrame:
67
+ """
68
+ Fit the transformer and transform X in one step.
69
+ Additionally ensures that the transformed training data contains at least one
70
+ row with Product_Type == 'Others' for downstream OneHotEncoder compatibility.
71
+ """
72
+ # Fit to create keep_set_
73
+ self.fit(X, y)
74
+ # Apply mapping to the data
75
+ X_trans = self.transform(X)
76
+
77
+ # If 'Others' already present, return transformed data as-is
78
+ if 'Others' in X_trans[self.product_col].unique():
79
+ return X_trans
80
+
81
+ # Build a synthetic row with Product_Type='Others'
82
+ synthetic: dict = {}
83
+ for col in X_trans.columns:
84
+ if col == self.product_col:
85
+ synthetic[col] = 'Others' # ensure 'Others' exists
86
+ else:
87
+ # Choose a safe default: mode for categorical, median for numeric
88
+ ser = X_trans[col].dropna()
89
+
90
+ if ser.empty:
91
+ synthetic[col] = np.nan
92
+ else:
93
+ # Check for categorical/object/string-like data
94
+ if pd.api.types.is_object_dtype(ser) or pd.api.types.is_categorical_dtype(ser) or pd.api.types.is_string_dtype(ser):
95
+ synthetic[col] = ser.mode().iloc[0]
96
+ else:
97
+ # Numeric fallback: ensure the median is a native Python type if possible, or NumPy float
98
+ synthetic[col] = float(ser.median()) if pd.api.types.is_numeric_dtype(ser) else ser.iloc[0] # Take first non-empty if non-numeric/non-mode
99
+
100
+ synthetic_df = pd.DataFrame([synthetic], columns=X_trans.columns)
101
+
102
+ # Append the synthetic row and return the augmented DataFrame
103
+ X_with_dummy = pd.concat([X_trans, synthetic_df], ignore_index=True)
104
+ return X_with_dummy
105
+
106
+
107
+ # ------------------ Hard-coded keep list (edit as needed) ------------------
108
+ # Define the KEEP_PRODUCT_TYPES set.
109
+ KEEP_PRODUCT_TYPES = {
110
+ 'Fruits and Vegetables', 'Snack Foods', 'Dairy', 'Frozen Foods', 'Household',
111
+ 'Baking Goods', 'Canned', 'Health and Hygiene', 'Meat', 'Soft Drinks'
112
+ }