Bachstelze commited on
Commit
1070098
·
1 Parent(s): 45180b0

add missing imports

Browse files
Files changed (1) hide show
  1. A5/CorrelationFilter.py +5 -1
A5/CorrelationFilter.py CHANGED
@@ -1,4 +1,7 @@
1
  from sklearn.base import BaseEstimator, TransformerMixin
 
 
 
2
 
3
  # Finds similar features that are highly correlated and remove it
4
  class CorrelationFilter(BaseEstimator, TransformerMixin):
@@ -8,7 +11,8 @@ class CorrelationFilter(BaseEstimator, TransformerMixin):
8
 
9
  def fit(self, X, y=None):
10
  Xdf = pd.DataFrame(X) if not isinstance(X, pd.DataFrame) else X
11
- # calculates the correlation matrix and takes absolutte values since negative values are also calculated
 
12
  corr = Xdf.corr(numeric_only=True).abs()
13
  upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
14
  to_drop = [col for col in upper.columns if any(upper[col] >= self.threshold)]
 
1
  from sklearn.base import BaseEstimator, TransformerMixin
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
 
6
  # Finds similar features that are highly correlated and remove it
7
  class CorrelationFilter(BaseEstimator, TransformerMixin):
 
11
 
12
  def fit(self, X, y=None):
13
  Xdf = pd.DataFrame(X) if not isinstance(X, pd.DataFrame) else X
14
+ # calculates the correlation matrix and takes absolutte values
15
+ # since negative values are also calculated
16
  corr = Xdf.corr(numeric_only=True).abs()
17
  upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
18
  to_drop = [col for col in upper.columns if any(upper[col] >= self.threshold)]