AlgoX commited on
Commit
2d00e44
·
1 Parent(s): 9c57596

feat : add function to remove all NaN values

Browse files
Files changed (1) hide show
  1. data_prep/data_clean.py +73 -0
data_prep/data_clean.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+ from typing import List
5
+
6
+ def clean_indicator(
7
+ df : pd.DataFrame,
8
+ feature_cols : List[str] | None = None,
9
+ drop_col_frac_threshold : float = 0.2
10
+ ) -> pd.DataFrame:
11
+
12
+ """
13
+ function to clean dataframe (remove all NaN values without lookahead)
14
+ args:
15
+ df : pandas dataframe
16
+ feature_cols : list of colms to treat as feats. Default -> None => use all colmns
17
+ drop_col_frac_threshold : drop the colmns with total NaN values greater than this
18
+ """
19
+
20
+ df = df.copy()
21
+
22
+ if feature_cols is None:
23
+
24
+ exclude = {
25
+ "Adj Close",
26
+ "Dividends",
27
+ "Stock Splits",
28
+ } # keep Close/Open/High/Low/Volume
29
+
30
+ feature_cols = [c for c in df.columns if c not in exclude]
31
+
32
+ #compute first valid and last valid positions
33
+ n = len(df)
34
+ first_positions = {}
35
+ last_positions = {}
36
+
37
+ for c in feature_cols:
38
+ fv = df[c].first_valid_index()
39
+ lv = df[c].last_valid_index()
40
+ first_positions[c] = df.index.get_loc(fv) if fv is not None else n
41
+ last_positions[c] = df.index.get_loc(lv) if lv is not None else -1
42
+
43
+ #triming window logic -> remove head warmi
44
+ #find features where all have values
45
+ start_pos = max(
46
+ first_positions.values()
47
+ )
48
+ #find last position where all features have values
49
+ end_pos = min(last_positions.values())
50
+
51
+ if start_pos >= end_pos:
52
+ # not enough overlap: as fallback, choose start = median of first positions, end = max of last positions
53
+ start_pos = int(np.median(list(first_positions.values())))
54
+ end_pos = int(np.median([pos for pos in last_positions.values() if pos >= 0]))
55
+
56
+ df_trim = df.iloc[start_pos : (end_pos + 1)].copy()
57
+
58
+ frac_nans = df_trim.isna().mean()
59
+ drop_cols = frac_nans[frac_nans > drop_col_frac_threshold].index.tolist()
60
+ # don't drop imp price columns
61
+ essential = {"Open", "High", "Low", "Close", "Volume"}
62
+
63
+ drop_cols = [c for c in drop_cols if c not in essential]
64
+
65
+ df_trim = df_trim.drop(columns=drop_cols)
66
+
67
+ df_imputed = df_trim.fillna(method="ffill") # type: ignore
68
+
69
+ medians = df_imputed.median()
70
+ df_imputed = df_imputed.fillna(medians)
71
+
72
+
73
+ return df_imputed