ArabovMK committed on
Commit bd3c428 · 1 Parent(s): f716d2c

Update all files

.gitignore ADDED
@@ -0,0 +1,9 @@
+ .venv/
+ .venv
+ __pycache__/
+ *__pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ streamlit_results/
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8501

  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health

- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ ENTRYPOINT ["streamlit", "run", "streamlit/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,20 +1,268 @@
  ---
- title: TimeFlowPro
- emoji: 🚀
- colorFrom: red
- colorTo: red
+ title: TimeFlow Pro
+ emoji: 📊
+ colorFrom: blue
+ colorTo: indigo
  sdk: docker
- app_port: 8501
- tags:
-   - streamlit
- pinned: false
- short_description: TimeFlowPro
- license: mit
+ pinned: true
+ app_file: app.py
+ sdk_version: 1.52.2
  ---

- # Welcome to Streamlit!
-
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
- If you have any questions, check out our [documentation](https://docs.streamlit.io) and [community forums](https://discuss.streamlit.io).
+ # 📊 TimeFlow Pro
+
+ <div align="center">
+
+ **Intelligent Time Series Data Analysis and Preprocessing Platform**
+
+ *Advanced pipeline for data preparation and feature engineering*
+
+ [![Hugging Face](https://img.shields.io/badge/🤗-Hugging%20Face%20Space-blue)](https://huggingface.co/spaces/your-username/timeflow-pro)
+ [![Streamlit](https://img.shields.io/badge/Interface-Streamlit-FF4B4B)](https://streamlit.io)
+ [![Python](https://img.shields.io/badge/Python-3.9+-blue)](https://python.org)
+
+ </div>
+
+ ## 🌟 Overview
+
+ TimeFlow Pro is a comprehensive platform for time series data analysis, preprocessing, and feature engineering. Designed for data scientists and analysts, it provides an intuitive interface for transforming raw time series data into ML-ready datasets with advanced preprocessing capabilities.
+
+ ## 🚀 Key Features
+
+ ### 📈 **Data Analysis & Visualization**
+ - **Interactive Data Exploration**: Real-time preview and statistics
+ - **Missing Value Analysis**: Smart detection and handling strategies
+ - **Outlier Detection**: Multiple methods including IQR, Z-Score, Isolation Forest
+ - **Temporal Analysis**: Seasonality detection, trend analysis, decomposition
+
+ ### ⚙️ **Advanced Preprocessing Pipeline**
+ - **Feature Engineering**: Automatic lag features, rolling statistics, seasonal components
+ - **Stationarity Checking**: ADF tests and transformation suggestions
+ - **Data Scaling**: Robust, Standard, MinMax, and custom scaling methods
+ - **Feature Selection**: Correlation, variance, mutual information, RF importance
+
+ ### 🏗️ **ML-Ready Outputs**
+ - **Train/Validation/Test Splits**: Time-based or random splitting
+ - **Multiple Export Formats**: CSV, Parquet, Excel, JSON
+ - **Model Integration**: Ready-to-use datasets for scikit-learn, XGBoost, LightGBM
+ - **Visual Reports**: Comprehensive pipeline execution reports
+
+ ## 🎮 Quick Start
+
+ ### 1. **Upload Your Data**
+ - Support for CSV, Excel, Parquet formats
+ - Automatic date parsing and validation
+ - Smart column type detection
+
+ ### 2. **Configure Pipeline**
+ ```python
+ # Example configuration
+ config = {
+     'target_column': 'sales',
+     'test_size': 0.2,
+     'max_lags': 5,
+     'seasonal_period': 365,
+     'scaling_method': 'robust'
+ }
+ ```
+
+ ### 3. **Run Pipeline & Export**
+ - Execute full preprocessing pipeline
+ - Download processed data
+ - Get feature importance reports
+ - Export modeling datasets
+
+ ## 📊 Technical Architecture
+
+ ### 🔧 **Pipeline Components**
+ ```
+ Data Loading → Validation → Missing Handling → Outlier Treatment
+                              ↓
+ Feature Engineering → Stationarity Check → Correlation Analysis
+                              ↓
+ Data Splitting → Scaling → Feature Selection → Final Validation
+ ```
+
+ ### 🏆 **Core Features**
+ - **Multi-stage Validation**: Raw, processed, and final data validation
+ - **Memory Optimization**: Efficient handling of large datasets
+ - **Error Recovery**: Graceful handling of pipeline failures
+ - **Reproducible Results**: Configuration saving and logging
+
+ ## 📚 Use Cases
+
+ ### 🏢 **Business Analytics**
+ - Sales forecasting and trend analysis
+ - Inventory optimization
+ - Customer behavior prediction
+ - Financial time series analysis
+
+ ### 🏭 **Industrial Applications**
+ - Sensor data preprocessing
+ - Predictive maintenance
+ - Quality control monitoring
+ - Energy consumption forecasting
+
+ ### 🎓 **Academic Research**
+ - Time series modeling experiments
+ - Feature engineering research
+ - Algorithm comparison studies
+ - Educational tool for data science
+
+ ## 🛠️ Installation
+
+ ### Local Development
+ ```bash
+ # Clone repository
+ git clone https://huggingface.co/spaces/your-username/timeflow-pro
+ cd timeflow-pro
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run application
+ streamlit run app.py
+ ```
+
+ ### Docker Deployment
+ ```bash
+ # Build Docker image
+ docker build -t timeflow-pro .
+
+ # Run container
+ docker run -p 8501:8501 timeflow-pro
+ ```
+
+ ## 🌐 API Usage Example
+
+ ```python
+ from timeflow_pro import TimeFlowPipeline
+ import pandas as pd
+
+ # Load your data
+ data = pd.read_csv('your_data.csv')
+
+ # Configure pipeline
+ config = {
+     'target_column': 'target',
+     'test_size': 0.2,
+     'max_lags': 7,
+     'seasonal_period': 30
+ }
+
+ # Create and run pipeline
+ pipeline = TimeFlowPipeline(config)
+ processed_data = pipeline.run(data)
+
+ # Get modeling data
+ modeling_data = pipeline.get_modeling_data()
+ X_train, y_train = modeling_data['X_train'], modeling_data['y_train']
+ ```
+
+ ## 📈 Performance Benchmarks
+
+ | Dataset Size | Processing Time | Memory Usage | Features Generated |
+ |--------------|-----------------|--------------|--------------------|
+ | 10K rows     | ~5 seconds      | <500 MB      | 50-100 features    |
+ | 100K rows    | ~30 seconds     | <1 GB        | 100-200 features   |
+ | 1M rows      | ~5 minutes      | <2 GB        | 200-500 features   |
+
+ ## 🔧 Configuration Options
+
+ ### **Data Processing**
+ - `missing_threshold`: Threshold for column removal (0.0-0.5)
+ - `outlier_method`: IQR, Z-Score, or Isolation Forest
+ - `scaling_method`: Robust, Standard, MinMax, or None
+
+ ### **Feature Engineering**
+ - `max_lags`: Maximum lag features (1-20)
+ - `seasonal_period`: Seasonal window (7, 30, 90, 365)
+ - `rolling_windows`: List of rolling windows [7, 30, 90]
+
+ ### **Model Preparation**
+ - `feature_selection_method`: Correlation, Variance, RF, Mutual Info
+ - `max_features`: Maximum features to select (5-100)
+ - `split_method`: Time-based or random splitting
+
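For illustration, a minimal sketch of how the options above might be combined, assuming the `Config` dataclass from `config/config.py` in this commit (the specific values are arbitrary):

```python
from config.config import Config

# One configuration touching each option group documented above.
config = Config(
    # Data processing
    missing_threshold=0.3,
    outlier_method='iqr',
    scaling_method='robust',
    # Feature engineering
    max_lags=7,
    seasonal_period=30,
    rolling_windows=[7, 30, 90],
    # Model preparation
    feature_selection_method='correlation',
    max_features=50,
    split_method='time',
)
```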
+ ## 📋 Requirements
+
+ ### **Core Dependencies**
+ ```txt
+ streamlit>=1.28.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ plotly>=5.17.0
+ scikit-learn>=1.3.0
+ ```
+
+ ### **Optional Dependencies**
+ ```txt
+ xgboost>=2.0.0       # For XGBoost feature importance
+ lightgbm>=4.0.0      # For LightGBM integration
+ statsmodels>=0.14.0  # For advanced time series analysis
+ ```
+
+ ## 🤝 Contributing
+
+ We welcome contributions! Here's how you can help:
+
+ ### **Areas for Contribution**
+ 1. **New Feature Engineering Methods**
+ 2. **Additional Visualization Types**
+ 3. **Export Format Support**
+ 4. **Performance Optimizations**
+ 5. **Documentation Improvements**
+
+ ### **Development Workflow**
+ ```bash
+ # 1. Fork the repository
+ # 2. Create feature branch
+ git checkout -b feature/new-feature
+
+ # 3. Make changes and test
+ # 4. Submit pull request
+ ```
+
+ ## 📜 License
+
+ This project is licensed under the **MIT License** - see the [LICENSE](LICENSE) file for details.
+
+ ## 🙏 Acknowledgments
+
+ ### **Special Thanks To:**
+ - **Streamlit Team** for the amazing framework
+ - **Hugging Face** for hosting the Space
+ - **Open Source Community** for invaluable libraries
+ - **All Contributors** who helped improve TimeFlow Pro
+
+ ### **Built With:**
+ - 🐍 Python
+ - 📊 Streamlit
+ - 🎨 Plotly
+ - 🔧 Scikit-learn
+ - 📈 Pandas & NumPy
+
+ ## 📞 Support & Contact
+
+ ### **Get Help:**
+ - 📧 **Email**: cool.araby@gmail.com
+ - 💬 **Issues**: [GitHub Issues](https://github.com/your-username/timeflow-pro/issues)
+ - 💡 **Discussions**: [Community Forum](https://github.com/your-username/timeflow-pro/discussions)
+
+ ### **Stay Updated:**
+ - ⭐ **Star** the repository
+ - 👁️ **Watch** for releases
+ - 🔔 **Enable notifications**
+
+ ---
+
+ <div align="center">
+
+ **Transform Your Time Series Data with Ease**
+
+ *TimeFlow Pro - Making Data Preparation Simple and Powerful*
+
+ [![Follow on Hugging Face](https://img.shields.io/badge/Follow%20on-🤗%20Hugging%20Face-yellow)](https://huggingface.co/your-username)
+ [![GitHub Stars](https://img.shields.io/github/stars/your-username/timeflow-pro?style=social)](https://github.com/your-username/timeflow-pro)
+
+ </div>
app.py ADDED
The diff for this file is too large to render. See raw diff
 
config/__init__.py ADDED
File without changes
config/config.py ADDED
@@ -0,0 +1,169 @@
+ # ============================================
+ # ENUMERATION CLASSES
+ # ============================================
+ from dataclasses import asdict, dataclass, field
+ from enum import Enum
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class DataType(Enum):
+     """Data types"""
+     NUMERIC = "numeric"
+     CATEGORICAL = "categorical"
+     TEMPORAL = "temporal"
+     TEXT = "text"
+
+
+ class PreprocessingMethod(Enum):
+     """Data preprocessing methods"""
+     FILL_MEAN = "fill_mean"
+     FILL_MEDIAN = "fill_median"
+     FILL_INTERPOLATE = "fill_interpolate"
+     FILL_KNN = "fill_knn"
+     REMOVE = "remove"
+     CLIP = "clip"
+     WINSORIZE = "winsorize"
+     NORMALIZE = "normalize"
+     STANDARDIZE = "standardize"
+     LOG_TRANSFORM = "log_transform"
+     BOX_COX = "box_cox"
+     DIFFERENCING = "differencing"
+
+
+ class SeasonalityType(Enum):
+     """Seasonality types"""
+     DAILY = "daily"
+     WEEKLY = "weekly"
+     MONTHLY = "monthly"
+     QUARTERLY = "quarterly"
+     YEARLY = "yearly"
+     MULTIPLE = "multiple"
+
+
+ # ============================================
+ # CLASS 1: CONFIGURATION
+ # ============================================
+ @dataclass
+ class Config:
+     """Experiment configuration for data preprocessing"""
+
+     # Paths and directories
+     data_path: str = 'temp_data.csv'
+     results_dir: str = 'data_preprocessing_results'
+
+     # Temporal parameters
+     start_year: int = 1970
+     end_year: int = 1990
+     freq: str = 'D'  # Data frequency: D (daily), H (hourly), M (monthly)
+
+     # Target variable
+     target_column: str = 'raskhodvoda'
+
+     # Feature parameters
+     max_lags: int = 12
+     seasonal_period: int = 365
+     rolling_windows: List[int] = field(default_factory=lambda: [7, 30, 90, 365])
+     expanding_windows: List[int] = field(default_factory=lambda: [30, 90, 365])
+
+     # Processing parameters
+     missing_threshold: float = 0.3  # Threshold for dropping columns with missing values
+     outlier_method: str = 'iqr'  # Outlier detection method: iqr, zscore, lof
+     outlier_alpha: float = 1.5  # IQR multiplier
+     outlier_contamination: float = 0.1  # For methods like LOF
+
+     # Data splitting
+     test_size: float = 0.2
+     validation_size: float = 0.1
+     split_method: str = 'time'  # time, random, expanding_window
+
+     # Scaling
+     scaling_method: str = 'robust'  # standard, minmax, robust, none
+
+     # Feature selection
+     feature_selection_method: str = 'correlation'  # correlation, mutual_info, rf, pca
+     max_features: int = 50
+
+     # Validation
+     enable_validation: bool = True
+     validation_rules: Dict = field(default_factory=dict)
+
+     # Visualisation
+     save_plots: bool = True
+     plot_style: str = 'seaborn'
+
+     # Performance
+     use_multiprocessing: bool = False
+     n_jobs: int = -1
+     chunk_size: int = 10000
+
+     # Logging
+     log_level: str = 'INFO'
+     save_reports: bool = True
+
+     def __post_init__(self):
+         """Post-initialisation for creating directories and setting up logging"""
+         self.create_directories()
+         self.setup_logging()
+
+         # Set default validation rules
+         if not self.validation_rules:
+             self.validation_rules = {
+                 'min_rows': 100,
+                 'max_missing_percentage': 30,
+                 'min_unique_values': 2,
+                 'max_skewness': 3,
+                 'max_kurtosis': 10
+             }
+
+     def create_directories(self) -> None:
+         """Create directories for preprocessing results"""
+         dirs = [
+             self.results_dir,
+             f'{self.results_dir}/plots',
+             f'{self.results_dir}/plots/time_series',
+             f'{self.results_dir}/plots/distributions',
+             f'{self.results_dir}/plots/correlations',
+             f'{self.results_dir}/plots/features',
+             f'{self.results_dir}/tables',
+             f'{self.results_dir}/processed_data',
+             f'{self.results_dir}/models',
+             f'{self.results_dir}/reports',
+             f'{self.results_dir}/logs',
+             f'{self.results_dir}/checkpoints'
+         ]
+
+         for directory in dirs:
+             Path(directory).mkdir(parents=True, exist_ok=True)
+
+         logger.info(f"Directories created in {self.results_dir}")
+
+     def setup_logging(self) -> None:
+         """Configure logging"""
+         log_level = getattr(logging, self.log_level.upper())
+         logger.setLevel(log_level)
+
+     def to_dict(self) -> Dict:
+         """Convert configuration to dictionary"""
+         return asdict(self)
+
+     def save(self, path: Optional[str] = None) -> None:
+         """Save configuration to file"""
+         if path is None:
+             path = f'{self.results_dir}/config.json'
+
+         with open(path, 'w', encoding='utf-8') as f:
+             json.dump(self.to_dict(), f, indent=4, ensure_ascii=False)
+
+         logger.info(f"Configuration saved to {path}")
+
+     @classmethod
+     def load(cls, path: str) -> 'Config':
+         """Load configuration from file"""
+         with open(path, 'r', encoding='utf-8') as f:
+             config_dict = json.load(f)
+
+         return cls(**config_dict)
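A minimal usage sketch for `Config`, assuming the module is importable as `config.config` (the paths and values are illustrative):

```python
from config.config import Config

# Instantiating Config creates the results directories and applies the
# log level as a side effect of __post_init__.
config = Config(
    data_path='temp_data.csv',
    target_column='raskhodvoda',
    scaling_method='robust',
)

config.save()  # writes <results_dir>/config.json

# Round-trip the configuration from disk.
restored = Config.load(f'{config.results_dir}/config.json')
assert restored.target_column == config.target_column
```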
config/default_config.json ADDED
@@ -0,0 +1,78 @@
+ {
+     "data_path": "temp_data.csv",
+     "results_dir": "results",
+
+     "start_year": 1970,
+     "end_year": 1990,
+     "freq": "D",
+
+     "target_column": "raskhodvoda",
+
+     "max_lags": 12,
+     "seasonal_period": 365,
+     "rolling_windows": [7, 30, 90, 365],
+     "expanding_windows": [30, 90, 365],
+
+     "missing_threshold": 0.3,
+     "outlier_method": "iqr",
+     "outlier_alpha": 1.5,
+     "outlier_contamination": 0.1,
+
+     "test_size": 0.2,
+     "validation_size": 0.1,
+     "split_method": "time",
+
+     "scaling_method": "robust",
+
+     "feature_selection_method": "correlation",
+     "max_features": 50,
+
+     "enable_validation": true,
+     "validation_rules": {
+         "min_rows": 100,
+         "max_missing_percentage": 30,
+         "min_unique_values": 2,
+         "max_skewness": 3,
+         "max_kurtosis": 10,
+         "min_variance": 0.001,
+         "max_constant_columns": 0
+     },
+
+     "save_plots": true,
+     "plot_style": "seaborn-whitegrid",
+     "plot_dpi": 300,
+     "plot_format": "png",
+
+     "use_multiprocessing": false,
+     "n_jobs": -1,
+     "chunk_size": 10000,
+     "memory_limit_gb": 4,
+
+     "log_level": "INFO",
+     "save_reports": true,
+     "report_format": "json",
+
+     "decomposition_method": "stl",
+     "stationarity_tests": ["adf", "kpss"],
+     "correlation_threshold": 0.85,
+     "vif_threshold": 10,
+
+     "random_seed": 42,
+     "enable_profiling": false,
+     "save_intermediate": true,
+
+     "streamlit_settings": {
+         "theme": "light",
+         "sidebar_state": "expanded",
+         "page_title": "Time Series Preprocessing",
+         "page_icon": "📊",
+         "layout": "wide"
+     },
+
+     "export_options": {
+         "csv": true,
+         "parquet": false,
+         "excel": false,
+         "pickle": true
+     }
+ }
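Note that this file carries keys the `Config` dataclass does not define (e.g. `plot_dpi`, `streamlit_settings`), so passing it wholesale to `Config.load()` would raise a `TypeError`. A sketch of reading it as a plain dictionary instead (the path assumes the repo root as working directory):

```python
import json

# Load the shipped defaults as a plain dict rather than a Config instance.
with open('config/default_config.json', encoding='utf-8') as f:
    defaults = json.load(f)

print(defaults['scaling_method'])      # 'robust'
print(defaults['streamlit_settings'])  # UI options consumed by the Streamlit app
```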
config/settings.py ADDED
@@ -0,0 +1,375 @@
+ """
+ General project settings: visualisation, paths, constants
+ """
+
+ import warnings
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+ import yaml
+ import json
+
+ # ============================================================================
+ # PATHS AND DIRECTORIES
+ # ============================================================================
+
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
+ DATA_DIR = PROJECT_ROOT / "data"
+ RAW_DATA_DIR = DATA_DIR / "raw"
+ PROCESSED_DATA_DIR = DATA_DIR / "processed"
+ EXTERNAL_DATA_DIR = DATA_DIR / "external"
+
+ RESULTS_DIR = PROJECT_ROOT / "results"
+ PLOTS_DIR = RESULTS_DIR / "plots"
+ MODELS_DIR = RESULTS_DIR / "models"
+ REPORTS_DIR = RESULTS_DIR / "reports"
+ LOGS_DIR = RESULTS_DIR / "logs"
+
+ CONFIGS_DIR = PROJECT_ROOT / "configs"
+ NOTEBOOKS_DIR = PROJECT_ROOT / "notebooks"
+ TESTS_DIR = PROJECT_ROOT / "tests"
+
+ # Create directories on import
+ for directory in [RAW_DATA_DIR, PROCESSED_DATA_DIR, EXTERNAL_DATA_DIR,
+                   PLOTS_DIR, MODELS_DIR, REPORTS_DIR, LOGS_DIR]:
+     directory.mkdir(parents=True, exist_ok=True)
+
+ # ============================================================================
+ # VISUALISATION SETTINGS
+ # ============================================================================
+
+ def setup_visualization(
+     style: str = "seaborn-whitegrid",
+     palette: str = "husl",
+     context: str = "notebook",
+     font_scale: float = 1.0,
+     dpi: int = 150,
+     figsize: tuple = (12, 6),
+     **kwargs
+ ):
+     """
+     Configure visualisation parameters for matplotlib and seaborn
+
+     Parameters:
+     -----------
+     style : str
+         Matplotlib style: 'seaborn-whitegrid', 'ggplot', 'bmh', 'dark_background'
+     palette : str
+         Seaborn palette: 'husl', 'Set2', 'viridis', 'mako'
+     context : str
+         Seaborn context: 'paper', 'notebook', 'talk', 'poster'
+     font_scale : float
+         Font scale
+     dpi : int
+         Plot resolution
+     figsize : tuple
+         Default figure size
+     """
+     # Ignore warnings
+     warnings.filterwarnings('ignore')
+
+     # Matplotlib settings ('seaborn-*' styles were renamed to
+     # 'seaborn-v0_8-*' in matplotlib >= 3.6, hence the fallback)
+     try:
+         plt.style.use(style)
+     except OSError:
+         plt.style.use(style.replace('seaborn-', 'seaborn-v0_8-'))
+
+     # RC parameters
+     rc_params = {
+         'font.size': 10,
+         'figure.figsize': figsize,
+         'figure.dpi': dpi,
+         'savefig.dpi': 300,
+         'savefig.bbox': 'tight',
+         'savefig.format': 'png',
+         'axes.titlesize': 12,
+         'axes.labelsize': 10,
+         'xtick.labelsize': 9,
+         'ytick.labelsize': 9,
+         'legend.fontsize': 9,
+         'font.family': ['DejaVu Sans', 'Arial', 'sans-serif'],
+         'figure.titlesize': 14,
+         'axes.grid': True,
+         'grid.alpha': 0.3,
+         'lines.linewidth': 1.5,
+         'lines.markersize': 6,
+         'patch.edgecolor': 'black',
+         'patch.force_edgecolor': True,
+         'xtick.top': False,
+         'ytick.right': False,
+         'axes.spines.top': False,
+         'axes.spines.right': False
+     }
+
+     # Update additional parameters
+     rc_params.update(kwargs)
+     plt.rcParams.update(rc_params)
+
+     # Seaborn settings
+     sns.set_style(style.replace('seaborn-', ''))
+     sns.set_palette(palette)
+     sns.set_context(context, font_scale=font_scale)
+
+     print(f"✓ Visualisation settings applied: style={style}, palette={palette}")
+
+
+ def get_color_palette(name: str = "husl", n_colors: int = 8) -> list:
+     """
+     Get colour palette
+
+     Parameters:
+     -----------
+     name : str
+         Palette name
+     n_colors : int
+         Number of colours
+
+     Returns:
+     --------
+     list
+         List of colours in HEX format
+     """
+     palette_map = {
+         "husl": sns.color_palette("husl", n_colors),
+         "Set2": sns.color_palette("Set2", n_colors),
+         "Set3": sns.color_palette("Set3", n_colors),
+         "viridis": sns.color_palette("viridis", n_colors),
+         "plasma": sns.color_palette("plasma", n_colors),
+         "coolwarm": sns.color_palette("coolwarm", n_colors),
+         "RdYlBu": sns.color_palette("RdYlBu", n_colors),
+         "Spectral": sns.color_palette("Spectral", n_colors),
+         "tab10": sns.color_palette("tab10", n_colors),
+         "tab20": sns.color_palette("tab20", n_colors),
+     }
+
+     palette = palette_map.get(name, sns.color_palette("husl", n_colors))
+     return [f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
+             for r, g, b in palette]
+
+
+ # ============================================================================
+ # CONSTANTS
+ # ============================================================================
+
+ # Data types
+ DATETIME_FORMATS = [
+     "%Y-%m-%d", "%Y/%m/%d", "%d.%m.%Y", "%d/%m/%Y",
+     "%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S",
+     "%d.%m.%Y %H:%M:%S", "%d/%m/%Y %H:%M:%S"
+ ]
+
+ # Metrics
+ METRICS = {
+     "regression": ["mse", "rmse", "mae", "mape", "r2", "explained_variance"],
+     "classification": ["accuracy", "precision", "recall", "f1", "roc_auc"]
+ }
+
+ # Statistical constants
+ STATS_CONSTANTS = {
+     "confidence_levels": [0.9, 0.95, 0.99],
+     "z_scores": {0.9: 1.645, 0.95: 1.96, 0.99: 2.576},
+     "outlier_multipliers": {"mild": 1.5, "extreme": 3.0}
+ }
+
+ # Time series parameters
+ TIME_SERIES_CONSTANTS = {
+     "frequencies": {
+         "H": "hourly",
+         "D": "daily",
+         "W": "weekly",
+         "M": "monthly",
+         "Q": "quarterly",
+         "Y": "yearly"
+     },
+     "seasonal_periods": {
+         "hourly": 24,
+         "daily": 7,
+         "weekly": 52,
+         "monthly": 12,
+         "quarterly": 4,
+         "yearly": 1
+     }
+ }
+
+ # ============================================================================
+ # CONFIGURATION UTILITIES
+ # ============================================================================
+
+ def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
+     """
+     Load configuration from file
+
+     Parameters:
+     -----------
+     config_path : str, optional
+         Path to configuration file
+
+     Returns:
+     --------
+     Dict[str, Any]
+         Configuration dictionary
+     """
+     if config_path is None:
+         config_path = CONFIGS_DIR / "default_config.json"
+
+     config_path = Path(config_path)
+
+     if not config_path.exists():
+         print(f"⚠ Configuration file not found: {config_path}")
+         return {}
+
+     # Determine file format
+     if config_path.suffix.lower() in ['.json']:
+         with open(config_path, 'r', encoding='utf-8') as f:
+             config = json.load(f)
+     elif config_path.suffix.lower() in ['.yaml', '.yml']:
+         with open(config_path, 'r', encoding='utf-8') as f:
+             config = yaml.safe_load(f)
+     else:
+         raise ValueError(f"Unsupported file format: {config_path.suffix}")
+
+     print(f"✓ Configuration loaded from: {config_path}")
+     return config
+
+
+ def save_config(config: Dict[str, Any], config_path: str) -> None:
+     """
+     Save configuration to file
+
+     Parameters:
+     -----------
+     config : Dict[str, Any]
+         Configuration to save
+     config_path : str
+         Save path
+     """
+     config_path = Path(config_path)
+     config_path.parent.mkdir(parents=True, exist_ok=True)
+
+     # Determine format
+     if config_path.suffix.lower() in ['.json']:
+         with open(config_path, 'w', encoding='utf-8') as f:
+             json.dump(config, f, indent=2, ensure_ascii=False)
+     elif config_path.suffix.lower() in ['.yaml', '.yml']:
+         with open(config_path, 'w', encoding='utf-8') as f:
+             yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
+     else:
+         raise ValueError(f"Unsupported file format: {config_path.suffix}")
+
+     print(f"✓ Configuration saved to: {config_path}")
+
+
+ def merge_configs(base_config: Dict[str, Any],
+                   override_config: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Recursive configuration merging
+
+     Parameters:
+     -----------
+     base_config : Dict[str, Any]
+         Base configuration
+     override_config : Dict[str, Any]
+         Override configuration
+
+     Returns:
+     --------
+     Dict[str, Any]
+         Merged configuration
+     """
+     result = base_config.copy()
+
+     for key, value in override_config.items():
+         if (key in result and isinstance(result[key], dict)
+                 and isinstance(value, dict)):
+             result[key] = merge_configs(result[key], value)
+         else:
+             result[key] = value
+
+     return result
+
+
+ # ============================================================================
+ # ENVIRONMENT SETUP
+ # ============================================================================
+
+ def setup_environment(
+     log_level: str = "INFO",
+     random_seed: int = 42,
+     enable_warnings: bool = False,
+     memory_limit_gb: Optional[int] = None
+ ) -> None:
+     """
+     Set up environment for reproducibility
+
+     Parameters:
+     -----------
+     log_level : str
+         Logging level
+     random_seed : int
+         Seed for random generators
+     enable_warnings : bool
+         Enable warnings
+     memory_limit_gb : int, optional
+         Memory limit in GB
+     """
+     import numpy as np
+     import random
+
+     # Set seeds
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+
+     # Optional deep-learning frameworks: seed them only when installed
+     try:
+         import torch
+         torch.manual_seed(random_seed)
+     except ImportError:
+         pass
+
+     try:
+         import tensorflow as tf
+         tf.random.set_seed(random_seed)
+     except ImportError:
+         pass
+
+     # Configure warnings
+     if enable_warnings:
+         warnings.filterwarnings('default')
+     else:
+         warnings.filterwarnings('ignore')
+
+     # Memory limit (if specified)
+     if memory_limit_gb:
+         import resource
+         soft, hard = resource.getrlimit(resource.RLIMIT_AS)
+         memory_limit = memory_limit_gb * 1024**3  # GB to bytes
+         resource.setrlimit(resource.RLIMIT_AS, (memory_limit, hard))
+         print(f"✓ Memory limit set: {memory_limit_gb} GB")
+
+     print(f"✓ Environment configured. Random seed: {random_seed}")
+
+
+ # ============================================================================
+ # AUTOMATIC SETUP ON IMPORT
+ # ============================================================================
+
+ # Automatically apply visualisation settings
+ setup_visualization()
+
+ # Export useful variables
+ __all__ = [
+     'setup_visualization',
+     'get_color_palette',
+     'load_config',
+     'save_config',
+     'merge_configs',
+     'setup_environment',
+     'PROJECT_ROOT',
+     'DATA_DIR',
+     'RAW_DATA_DIR',
+     'PROCESSED_DATA_DIR',
+     'RESULTS_DIR',
+     'PLOTS_DIR',
+     'DATETIME_FORMATS',
+     'METRICS',
+     'STATS_CONSTANTS',
+     'TIME_SERIES_CONSTANTS'
+ ]
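A short sketch of how these helpers might be used together (file paths are illustrative; note the module creates directories and applies visualisation settings as a side effect of import):

```python
from config.settings import load_config, merge_configs, save_config, setup_environment

# Seed numpy/random (and torch/tensorflow when installed) for reproducibility.
setup_environment(random_seed=42)

base = load_config('config/default_config.json')  # the defaults shipped in this repo
overrides = {
    'scaling_method': 'standard',
    'streamlit_settings': {'theme': 'dark'},  # nested dicts are merged recursively
}

merged = merge_configs(base, overrides)
save_config(merged, 'results/run_config.json')
```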
correlations/__init__.py ADDED
File without changes
correlations/correlation_analyzer.py ADDED
@@ -0,0 +1,687 @@
+ # ============================================
+ # CLASS 8: CORRELATION AND MULTICOLLINEARITY ANALYSIS
+ # ============================================
+ import logging
+ import os
+ import traceback
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import pandas as pd
+
+ from config.config import Config
+
+ logger = logging.getLogger(__name__)
+
+
+ class CorrelationAnalyzer:
+     """Class for comprehensive correlation and multicollinearity analysis"""
+
+     def __init__(self, config: Config):
+         """
+         Initialise the analyser
+
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.correlation_matrices = {}
+         self.high_correlation_pairs = {}
+         self.multicollinearity_info = {}
+         self.vif_scores = {}
+
+     def analyze(
+         self,
+         data: pd.DataFrame,
+         target_col: Optional[str] = None,
+         threshold: float = 0.8,
+         detailed: bool = True,
+         **kwargs
+     ) -> pd.DataFrame:
+         """
+         Analyse correlations in the data
+
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str, optional
+             Target variable
+         threshold : float
+             Threshold for identifying high correlations
+         detailed : bool
+             Whether to perform detailed analysis
+         **kwargs : dict
+             Additional parameters
+
+         Returns:
+         --------
+         pd.DataFrame
+             Correlation matrix
+         """
+         logger.info("\n" + "="*80)
+         logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS")
+         logger.info("="*80)
+
+         target_col = target_col or self.config.target_column
+
+         try:
+             # 1. Calculate correlation matrix
+             corr_matrix = self._compute_correlations(data, target_col)
+
+             if corr_matrix.empty:
+                 logger.warning("Correlation matrix is empty")
+                 return pd.DataFrame()
+
+             # 2. Identify high correlations
+             high_correlations = self._detect_high_correlations(corr_matrix, threshold)
+             self.high_correlation_pairs['pearson'] = high_correlations
+
+             # 3. Analyse correlations with target variable
+             target_correlations = []
+             if target_col in corr_matrix.columns:
+                 target_correlations = self._get_target_correlations(corr_matrix, target_col)
+
+             # 4. Analyse multicollinearity (VIF)
+             vif_results = self._compute_vif_scores(data)
+
+             # 5. Detailed analysis if required
+             if detailed:
+                 self._detailed_correlation_analysis(data, corr_matrix, target_col)
+
+             # 6. Visualisation
+             if self.config.save_plots:
+                 self._plot_correlation_analysis(data, corr_matrix, target_col,
+                                                 high_correlations, vif_results)
+
+             # 7. Output results
+             self._log_analysis_results(corr_matrix, high_correlations,
+                                        target_correlations, vif_results)
+
+             return corr_matrix
+
+         except Exception as e:
+             logger.error(f"Error in correlation analysis: {e}")
+             logger.error(traceback.format_exc())
+             return pd.DataFrame()
+
+     def _compute_correlations(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """Calculate correlation matrix"""
+         logger.info("Calculating correlation matrix...")
+
+         # Select only numeric columns
+         numeric_data = data.select_dtypes(include=[np.number])
+
+         # Remove constant columns
+         numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
+
+         if numeric_data.shape[1] < 2:
+             logger.warning("Insufficient numeric features for analysis")
+             return pd.DataFrame()
+
+         # Remove missing values
+         numeric_data_clean = numeric_data.dropna()
+
+         if len(numeric_data_clean) < 10:
+             logger.warning("Insufficient data after cleaning")
+             return pd.DataFrame()
+
+         # Calculate Pearson correlation
+         try:
+             corr_matrix = numeric_data_clean.corr(method='pearson')
+             self.correlation_matrices['pearson'] = corr_matrix
+             logger.info(f"✓ Correlation matrix calculated: {corr_matrix.shape}")
+             return corr_matrix
+         except Exception as e:
+             logger.error(f"Error calculating correlation: {e}")
+             return pd.DataFrame()
+
+     def _detect_high_correlations(
+         self,
+         corr_matrix: pd.DataFrame,
+         threshold: float = 0.8
+     ) -> List[Dict[str, Any]]:
+         """Detect high correlations"""
+         high_correlations = []
+
+         if corr_matrix.empty:
+             return high_correlations
+
+         # Use upper triangle of matrix
+         upper_triangle = corr_matrix.where(
+             np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
+         )
+
+         # Find pairs with correlation above threshold
+         for col in upper_triangle.columns:
+             high_corr_series = upper_triangle[col][abs(upper_triangle[col]) > threshold]
+
+             for row_idx, correlation in high_corr_series.items():
+                 if not pd.isna(correlation):
+                     high_correlations.append({
+                         'feature1': row_idx,
+                         'feature2': col,
+                         'correlation': float(correlation),
+                         'abs_correlation': abs(float(correlation))
+                     })
+
+         # Sort by absolute correlation value
+         high_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
+
+         logger.info(f"High correlations detected (> {threshold}): {len(high_correlations)}")
+         return high_correlations
+
+     def _get_target_correlations(
+         self,
+         corr_matrix: pd.DataFrame,
+         target_col: str
+     ) -> List[Dict[str, Any]]:
+         """Get correlations with target variable"""
+         target_correlations = []
+
+         if target_col not in corr_matrix.columns:
+             return target_correlations
+
+         # Extract correlations with target variable
+         target_corr_series = corr_matrix[target_col]
+
+         for feature, correlation in target_corr_series.items():
+             if feature != target_col and not pd.isna(correlation):
+                 target_correlations.append({
+                     'feature': feature,
+                     'correlation': float(correlation),
+                     'abs_correlation': abs(float(correlation)),
+                     'direction': 'positive' if correlation > 0 else 'negative'
+                 })
+
+         # Sort by absolute value
+         target_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
+
+         logger.info(f"Correlations with target variable calculated: {len(target_correlations)}")
+         return target_correlations
+
+     def _compute_vif_scores(self, data: pd.DataFrame) -> Dict[str, Any]:
+         """Calculate VIF (Variance Inflation Factor)"""
+         logger.info("Analysing multicollinearity (VIF)...")
+
+         vif_results = {
+             'scores': {},
+             'issues': [],
+             'summary': {
+                 'critical': 0,
+                 'high': 0,
+                 'medium': 0,
+                 'low': 0
+             }
+         }
+
+         try:
+             from statsmodels.stats.outliers_influence import variance_inflation_factor
+             import statsmodels.api as sm
+
+             # Prepare data
+             numeric_data = data.select_dtypes(include=[np.number])
+             numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
+
+             # Remove missing and infinite values
+             clean_data = numeric_data.replace([np.inf, -np.inf], np.nan).dropna()
+
+             if clean_data.shape[0] < 10 or clean_data.shape[1] < 2:
+                 logger.warning("Insufficient data for VIF analysis")
+                 return vif_results
+
+             # Add constant
+             X = sm.add_constant(clean_data, has_constant='add')
+
+             # Calculate VIF for each feature
+             vif_scores = {}
+             for i, column in enumerate(X.columns):
+                 if column == 'const':
+                     continue
+
+                 try:
+                     vif = variance_inflation_factor(X.values, i)
+
+                     # Handle extreme values
+                     if np.isinf(vif) or vif > 1e6:
+                         vif = 1e6
+
+                     vif_scores[column] = float(vif)
+
+                     # Classify by severity
+                     if vif > 100:
+                         vif_results['summary']['critical'] += 1
+                         vif_results['issues'].append({
+                             'feature': column,
+                             'vif': float(vif),
+                             'severity': 'critical',
+                             'recommendation': 'Remove feature'
+                         })
+                     elif vif > 10:
+                         vif_results['summary']['high'] += 1
+                         vif_results['issues'].append({
+                             'feature': column,
+                             'vif': float(vif),
+                             'severity': 'high',
+                             'recommendation': 'Consider removal'
+                         })
+                     elif vif > 5:
+                         vif_results['summary']['medium'] += 1
+                     else:
+                         vif_results['summary']['low'] += 1
+
+                 except Exception as e:
+                     logger.warning(f"VIF error for {column}: {e}")
+                     vif_scores[column] = np.nan
+
+             vif_results['scores'] = vif_scores
+             self.vif_scores = vif_scores
+             # Keep the aggregate counts separately for reporting (see get_report)
+             self.multicollinearity_info['vif_summary'] = vif_results['summary']
+
+             logger.info(f"✓ VIF analysis completed. Critical features: {vif_results['summary']['critical']}")
+
+         except ImportError:
+             logger.warning("statsmodels not installed, skipping VIF analysis")
+         except Exception as e:
+             logger.error(f"VIF analysis error: {e}")
+
+         return vif_results
+
+     def _detailed_correlation_analysis(
+         self,
+         data: pd.DataFrame,
+         corr_matrix: pd.DataFrame,
+         target_col: str
+     ) -> None:
+         """Detailed correlation analysis"""
+         # Analyse correlation clusters
+         if not corr_matrix.empty and corr_matrix.shape[0] > 3:
+             try:
+                 # Use clustering to group correlated features
+                 from scipy.cluster.hierarchy import linkage, fcluster
+                 from scipy.spatial.distance import squareform
+
+                 # Convert correlations to distances
+                 distance_matrix = 1 - abs(corr_matrix)
+                 np.fill_diagonal(distance_matrix.values, 0)
+
+                 # Clustering
+                 condensed_dist = squareform(distance_matrix)
+                 Z = linkage(condensed_dist, method='average')
+
+                 # Determine clusters
+                 clusters = fcluster(Z, t=0.5, criterion='distance')
+
+                 # Group features by cluster
+                 feature_clusters = {}
+                 for idx, cluster_id in enumerate(clusters):
+                     feature = corr_matrix.columns[idx]
+                     if cluster_id not in feature_clusters:
+                         feature_clusters[cluster_id] = []
+                     feature_clusters[cluster_id].append(feature)
+
+                 # Save cluster information
+                 self.multicollinearity_info['correlation_clusters'] = feature_clusters
+                 logger.info(f"Correlated feature clusters detected: {len(feature_clusters)}")
+
+             except Exception as e:
+                 logger.debug(f"Cluster analysis failed: {e}")
+
+     def _plot_correlation_analysis(
+         self,
+         data: pd.DataFrame,
+         corr_matrix: pd.DataFrame,
+         target_col: str,
+         high_correlations: List[Dict[str, Any]],
+         vif_results: Dict[str, Any]
+     ) -> None:
+         """Visualise correlation analysis"""
+         try:
+             import matplotlib.pyplot as plt
+             import seaborn as sns
+             from matplotlib import rcParams
+
+             # Style settings
+             plt.style.use('seaborn-v0_8-darkgrid')
+             rcParams.update({
+                 'figure.figsize': (12, 8),
+                 'font.size': 10,
+                 'axes.titlesize': 14,
+                 'axes.labelsize': 12
+             })
+
+             # Create directory
+             plots_dir = os.path.join(self.config.results_dir, 'plots', 'correlations')
+             os.makedirs(plots_dir, exist_ok=True)
+
+             # 1. Correlation matrix heatmap
+             if not corr_matrix.empty and corr_matrix.shape[0] > 1:
+                 fig, ax = plt.subplots(figsize=(14, 12))
+
+                 mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
+                 sns.heatmap(
+                     corr_matrix,
+                     mask=mask,
+                     annot=True,
+                     fmt='.2f',
+                     cmap='coolwarm',
+                     center=0,
+                     square=True,
+                     linewidths=0.5,
+                     cbar_kws={"shrink": 0.8},
+                     ax=ax
+                 )
+                 ax.set_title('Correlation Matrix (Pearson)', fontweight='bold')
+                 plt.tight_layout()
+                 plt.savefig(os.path.join(plots_dir, 'correlation_matrix.png'),
+                             dpi=150, bbox_inches='tight')
+                 plt.close()
+
+             # 2. Target variable correlations
+             if target_col in corr_matrix.columns:
+                 target_corrs = corr_matrix[target_col].drop(target_col, errors='ignore')
+                 if not target_corrs.empty:
+                     fig, ax = plt.subplots(figsize=(10, 8))
+
+                     top_corrs = target_corrs.abs().sort_values(ascending=True).tail(20)
+                     colors = ['red' if target_corrs[feat] < 0 else 'blue'
+                               for feat in top_corrs.index]
+
+                     ax.barh(range(len(top_corrs)), top_corrs.values, color=colors)
+                     ax.set_yticks(range(len(top_corrs)))
+                     ax.set_yticklabels(top_corrs.index)
+                     ax.set_xlabel('Absolute correlation')
+                     ax.set_title(f'Top-20 correlations with {target_col}', fontweight='bold')
+                     ax.grid(True, alpha=0.3, axis='x')
+
+                     plt.tight_layout()
+                     plt.savefig(os.path.join(plots_dir, 'target_correlations.png'),
+                                 dpi=150, bbox_inches='tight')
+                     plt.close()
+
+             # 3. VIF scores plot
+             if vif_results['scores']:
+                 valid_scores = {k: v for k, v in vif_results['scores'].items()
+                                 if not pd.isna(v)}
+                 if valid_scores:
+                     fig, ax = plt.subplots(figsize=(12, 8))
+
+                     sorted_scores = dict(sorted(valid_scores.items(),
+                                                 key=lambda x: x[1],
+                                                 reverse=True)[:25])
+
+                     colors = []
+                     for vif in sorted_scores.values():
+                         if vif > 100:
+                             colors.append('red')
+                         elif vif > 10:
+                             colors.append('orange')
+                         elif vif > 5:
+                             colors.append('yellow')
+                         else:
+                             colors.append('green')
+
+                     ax.barh(list(sorted_scores.keys()),
+                             list(sorted_scores.values()),
+                             color=colors, edgecolor='black')
+
+                     ax.set_xlabel('VIF Score')
+                     ax.set_title('VIF Scores (multicollinearity)', fontweight='bold')
+                     ax.axvline(x=5, color='yellow', linestyle='--', alpha=0.7)
+                     ax.axvline(x=10, color='orange', linestyle='--', alpha=0.7)
+                     ax.axvline(x=100, color='red', linestyle='--', alpha=0.7)
+                     ax.grid(True, alpha=0.3, axis='x')
+
+                     plt.tight_layout()
+                     plt.savefig(os.path.join(plots_dir, 'vif_scores.png'),
+                                 dpi=150, bbox_inches='tight')
+                     plt.close()
+
+             # 4. High correlations plot
+             if high_correlations:
+                 fig, ax = plt.subplots(figsize=(12, 8))
+
+                 # Limit number for display
+                 display_corrs = high_correlations[:15]
+
+                 # Create labels for feature pairs
+                 labels = [f"{corr['feature1']} ↔ {corr['feature2']}"
+                           for corr in display_corrs]
+                 values = [corr['correlation'] for corr in display_corrs]
+                 colors = ['red' if v < 0 else 'blue' for v in values]
+
+                 y_pos = np.arange(len(display_corrs))
+                 ax.barh(y_pos, values, color=colors)
+                 ax.set_yticks(y_pos)
+                 ax.set_yticklabels(labels, fontsize=9)
+                 ax.invert_yaxis()
+                 ax.set_xlabel('Correlation')
+                 ax.set_title('High correlations (> 0.8)', fontweight='bold')
+                 ax.grid(True, alpha=0.3, axis='x')
+
+                 plt.tight_layout()
+                 plt.savefig(os.path.join(plots_dir, 'high_correlations.png'),
+                             dpi=150, bbox_inches='tight')
+                 plt.close()
+
+             logger.info(f"Visualisations saved to {plots_dir}")
+
+         except Exception as e:
+             logger.warning(f"Error creating visualisations: {e}")
+
+     def _log_analysis_results(
+         self,
+         corr_matrix: pd.DataFrame,
+         high_correlations: List[Dict[str, Any]],
+         target_correlations: List[Dict[str, Any]],
+         vif_results: Dict[str, Any]
+     ) -> None:
+         """Log analysis results"""
+         logger.info("\n" + "="*80)
+         logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS REPORT")
+         logger.info("="*80)
+
+         # General information
+         logger.info("\n📊 GENERAL INFORMATION:")
+         logger.info(f"  Correlation matrix size: {corr_matrix.shape}")
+         logger.info(f"  Total features: {len(corr_matrix.columns)}")
+
+         # High correlations
+         if high_correlations:
+             logger.info(f"\n⚠ HIGH CORRELATIONS (|r| > 0.8): {len(high_correlations)}")
+             logger.info("  " + "-" * 60)
+
+             for i, corr in enumerate(high_correlations[:10]):
+                 sign = "🟥" if corr['correlation'] < 0 else "🟩"
+                 logger.info(f"  {i+1:2d}. {sign} {corr['feature1']:25s} ↔ {corr['feature2']:25s}: {corr['correlation']:7.4f}")
+
+             if len(high_correlations) > 10:
+                 logger.info(f"  ... and {len(high_correlations) - 10} more pairs")
+
+         # Target variable correlations
+         if target_correlations:
+             logger.info("\n🎯 CORRELATIONS WITH TARGET VARIABLE:")
+             logger.info("  " + "-" * 60)
+
+             for i, corr in enumerate(target_correlations[:10]):
+                 direction = "↓" if corr['correlation'] < 0 else "↑"
+                 logger.info(f"  {i+1:2d}. {direction} {corr['feature']:35s}: {corr['correlation']:7.4f}")
+
+         # Multicollinearity analysis
+         if vif_results['scores']:
+             logger.info("\n📈 MULTICOLLINEARITY ANALYSIS (VIF):")
+             logger.info("  " + "-" * 60)
+             logger.info(f"  Critical (VIF > 100):  {vif_results['summary']['critical']}")
+             logger.info(f"  High (10 < VIF ≤ 100): {vif_results['summary']['high']}")
+             logger.info(f"  Medium (5 < VIF ≤ 10): {vif_results['summary']['medium']}")
+             logger.info(f"  Low (VIF ≤ 5):         {vif_results['summary']['low']}")
+
+             # Top problematic features
+             if vif_results['issues']:
+                 logger.info("\n🔴 PROBLEMATIC FEATURES (VIF > 10):")
+                 for issue in vif_results['issues'][:10]:
+                     logger.info(f"  • {issue['feature']:35s}: VIF = {issue['vif']:7.1f} ({issue['severity']})")
+
+         logger.info("\n" + "="*80)
+         logger.info("RECOMMENDATIONS:")
+         logger.info("="*80)
+
+         # Generate recommendations
+         recommendations = []
+
+         if len(high_correlations) > 20:
+             recommendations.append("1. Remove highly correlated features (correlation method)")
+
+         if vif_results['summary']['critical'] > 0:
+             recommendations.append("2. Remove features with critical VIF (>100)")
+
+         if vif_results['summary']['high'] > 5:
+             recommendations.append("3. Consider removing features with VIF > 10")
+
+         if not recommendations:
+             recommendations.append("1. Data in good condition, no serious issues detected")
+             recommendations.append("2. Proceed to modelling")
+
+         for rec in recommendations:
+             logger.info(f"  {rec}")
+
+         logger.info("\n" + "="*80)
+
+     def remove_highly_correlated(
+         self,
+         data: pd.DataFrame,
+         threshold: float = 0.85,
+         method: str = 'variance',
+         keep_target: bool = True,
+         keep_features: Optional[List[str]] = None
+     ) -> pd.DataFrame:
+         """
+         Remove highly correlated features
+
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Source data
+         threshold : float
+             Correlation threshold for removal
+         method : str
+             Feature selection method for removal: 'variance', 'random', 'importance'
+         keep_target : bool
+             Whether to keep target variable
+         keep_features : List[str], optional
+             Features to keep
+
+         Returns:
+         --------
+         pd.DataFrame
+             Data after removing highly correlated features
+         """
+         logger.info("\n" + "="*80)
+         logger.info("REMOVING HIGHLY CORRELATED FEATURES")
+         logger.info("="*80)
+
+         data_clean = data.copy()
+
+         if 'pearson' not in self.correlation_matrices:
+             logger.warning("Correlation matrix not calculated, run analyze() first")
+             return data_clean
+
+         corr_matrix = self.correlation_matrices['pearson']
+
+         # Features to keep
+         features_to_keep = set()
+
+         if keep_target and self.config.target_column in data_clean.columns:
+             features_to_keep.add(self.config.target_column)
+
+         if keep_features:
+             for feat in keep_features:
+                 if feat in data_clean.columns:
+                     features_to_keep.add(feat)
+
+         # Temporal features (usually important for time series)
+         temporal_patterns = ['year', 'month', 'day', 'week', 'quarter',
+                              'hour', 'minute', 'second', 'sin', 'cos']
+
+         for col in data_clean.columns:
+             if any(pattern in col.lower() for pattern in temporal_patterns):
+                 features_to_keep.add(col)
+
+         # Find highly correlated pairs
+         upper_triangle = corr_matrix.where(
+             np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
+         )
+
+         # Collect highly correlated features
+         correlated_features = set()
+         for col in upper_triangle.columns:
+             if col in features_to_keep:
+                 continue
+
+             high_corr = upper_triangle[col][abs(upper_triangle[col]) > threshold]
+             for row_idx, corr_value in high_corr.items():
+                 if not pd.isna(corr_value) and row_idx not in features_to_keep:
+                     # Select which feature to remove
+                     if method == 'variance':
+                         # Remove the one with lower variance
+                         var_col = data_clean[col].var()
+                         var_row = data_clean[row_idx].var()
+                         feature_to_remove = col if var_col < var_row else row_idx
+                     elif method == 'importance':
+                         # Remove the one with lower correlation to target variable
+                         if self.config.target_column in corr_matrix.columns:
+                             corr_col_target = abs(corr_matrix.loc[col, self.config.target_column])
+                             corr_row_target = abs(corr_matrix.loc[row_idx, self.config.target_column])
+                             feature_to_remove = col if corr_col_target < corr_row_target else row_idx
+                         else:
+                             # If no target, remove randomly
+                             feature_to_remove = np.random.choice([col, row_idx])
+                     else:
+                         # Remove randomly
+                         feature_to_remove = np.random.choice([col, row_idx])
+
+                     correlated_features.add(feature_to_remove)
+
+         # Remove features
+         features_to_remove = list(correlated_features)
+
+         if features_to_remove:
+             data_clean = data_clean.drop(columns=features_to_remove)
+
+         logger.info("\n📊 REMOVAL RESULTS:")
+         logger.info(f"  Initial feature count: {len(data.columns)}")
+         logger.info(f"  Features removed: {len(features_to_remove)}")
+         logger.info(f"  Final feature count: {len(data_clean.columns)}")
+         logger.info(f"  Retained: {len(data_clean.columns)/len(data.columns)*100:.1f}%")
+
+         if features_to_remove:
+             logger.info("\n🗑️ REMOVED FEATURES:")
+             for i, feat in enumerate(sorted(features_to_remove)[:20]):
+                 logger.info(f"  {i+1:2d}. {feat}")
+             if len(features_to_remove) > 20:
+                 logger.info(f"  ... and {len(features_to_remove) - 20} more features")
+         else:
+             logger.info("✓ No highly correlated features detected, all features retained")
+
+         logger.info("="*80)
+         return data_clean
+
+     def get_report(self) -> Dict[str, Any]:
+         """Get analysis report"""
+         report = {
+             "correlation_matrix_shape": None,
+             "high_correlation_count": 0,
+             "vif_summary": {},
+             "target_correlation_count": 0
+         }
+
+         if 'pearson' in self.correlation_matrices:
+             report["correlation_matrix_shape"] = self.correlation_matrices['pearson'].shape
+
+         if 'pearson' in self.high_correlation_pairs:
+             report["high_correlation_count"] = len(self.high_correlation_pairs['pearson'])
+
+         # self.vif_scores maps feature -> VIF, so the aggregate counts are
+         # kept in multicollinearity_info (set in _compute_vif_scores)
+         report["vif_summary"] = self.multicollinearity_info.get('vif_summary', {})
+
+         return report
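A minimal end-to-end sketch of the analyser on synthetic data (the column names are arbitrary; `save_plots=False` skips the plot files):

```python
import numpy as np
import pandas as pd

from config.config import Config
from correlations.correlation_analyzer import CorrelationAnalyzer

rng = np.random.default_rng(42)
n = 500
df = pd.DataFrame({'x1': rng.normal(size=n), 'x3': rng.normal(size=n)})
df['x2'] = 0.95 * df['x1'] + rng.normal(scale=0.1, size=n)  # nearly collinear with x1
df['raskhodvoda'] = 2 * df['x1'] + df['x3'] + rng.normal(size=n)

analyzer = CorrelationAnalyzer(Config(save_plots=False))

# analyze() caches the Pearson matrix that remove_highly_correlated() needs.
corr_matrix = analyzer.analyze(df, target_col='raskhodvoda', threshold=0.8)
reduced = analyzer.remove_highly_correlated(df, threshold=0.85, method='importance')
print(reduced.columns.tolist())  # one of the x1/x2 pair should be dropped
```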
data_loader/__init__.py ADDED
File without changes
data_loader/data_loader.py ADDED
@@ -0,0 +1,487 @@
+ # ============================================
+ # CLASS 2: DATA LOADER
+ # ============================================
+ from datetime import datetime
+ import hashlib
+ import json
+ import logging
+ import traceback
+ from typing import Dict, List, Optional
+
+ import numpy as np
+ import pandas as pd
+
+ from config.config import Config, DataType
+
+ logger = logging.getLogger(__name__)
+
+
+ class DataLoader:
+     """Class for loading and initial data processing"""
+
+     def __init__(self, config: Config):
+         """
+         Initialise data loader
+
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.data = None
+         self.metadata = {}
+         self.data_hash = None
+         self.loading_time = None
+         self.data_types = {}
+         self.original_shape = None
+
+     def load_from_csv(
+         self,
+         data_path: Optional[str] = None,
+         parse_dates: Optional[List[str]] = None,
+         date_format: Optional[str] = None,
+         dtype: Optional[Dict] = None,
+         **kwargs
+     ) -> pd.DataFrame:
+         """
+         Load data from CSV file
+
+         Parameters:
+         -----------
+         data_path : str, optional
+             Path to CSV file. If None, uses path from configuration.
+         parse_dates : List[str], optional
+             List of columns to parse as dates
+         date_format : str, optional
+             Date format
+         dtype : Dict, optional
+             Data types for columns
+         **kwargs : dict
+             Additional parameters for pd.read_csv
+
+         Returns:
+         --------
+         pd.DataFrame
+             Loaded data
+         """
+         logger.info("="*80)
+         logger.info("LOADING DATA FROM CSV")
+         logger.info("="*80)
+
+         start_time = datetime.now()
+
+         try:
+             path = data_path or self.config.data_path
+
+             if parse_dates is None:
+                 parse_dates = ['date']
+
+             # Load data
+             self.data = pd.read_csv(
+                 path,
+                 parse_dates=parse_dates,
+                 dayfirst=False,
+                 dtype=dtype,
+                 **kwargs
+             )
+
+             # Convert dates if needed
+             for date_col in parse_dates:
+                 if date_col in self.data.columns:
+                     if date_format:
+                         self.data[date_col] = pd.to_datetime(
+                             self.data[date_col],
+                             format=date_format,
+                             errors='coerce'
+                         )
+                     else:
+                         self.data[date_col] = pd.to_datetime(
+                             self.data[date_col],
+                             errors='coerce'
+                         )
+
+             # Save original shape
+             self.original_shape = self.data.shape
+
+             # Filter by years
+             if 'date' in self.data.columns:
+                 mask = (self.data['date'].dt.year >= self.config.start_year) & \
+                        (self.data['date'].dt.year <= self.config.end_year)
+                 self.data = self.data.loc[mask].copy()
+
+             # Sort by date
+             if 'date' in self.data.columns:
+                 self.data = self.data.sort_values('date').reset_index(drop=True)
+                 # Set date as index
+                 self.data.set_index('date', inplace=True)
+
+             # Calculate data hash
+             self.data_hash = self._calculate_data_hash()
+
+             # Analyse data types
+             self._analyse_data_types()
+
+             # Save metadata
+             self._save_metadata()
+
+             # Loading time
+             self.loading_time = (datetime.now() - start_time).total_seconds()
+
+             logger.info(f"✓ Loaded {len(self.data)} records, {len(self.data.columns)} columns")
+             logger.info(f"  Period: {self.data.index.min()} - {self.data.index.max()}")
+             logger.info(f"  Data types: {self.data_types}")
+             logger.info(f"  Target variable: {self.config.target_column}")
+             logger.info(f"  Loading time: {self.loading_time:.2f} sec")
+
+             return self.data
+
+         except Exception as e:
+             logger.error(f"✗ Error loading data: {e}")
+             logger.error(traceback.format_exc())
+             raise
+
+     def create_synthetic_data(
+         self,
+         n_days: int = 365*21,
+         trend_strength: float = 0.01,
+         seasonal_amplitude: Optional[List[float]] = None,
+         noise_std: float = 10,
+         include_exogenous: bool = True,
+         random_state: int = 42
+     ) -> pd.DataFrame:
+         """
+         Create synthetic data for testing
+
+         Parameters:
+         -----------
+         n_days : int
+             Number of days to generate
+         trend_strength : float
+             Trend strength
+         seasonal_amplitude : List[float], optional
+             Seasonal component amplitudes
+         noise_std : float
+             Noise standard deviation
+         include_exogenous : bool
+             Whether to include exogenous variables
+         random_state : int
+             Seed for reproducibility
+
+         Returns:
+         --------
+         pd.DataFrame
+             Synthetic data
+         """
+         logger.info("="*80)
+         logger.info("CREATING SYNTHETIC DATA")
+         logger.info("="*80)
+
+         if seasonal_amplitude is None:
+             seasonal_amplitude = [50, 30, 20]
+
+         np.random.seed(random_state)
+
+         # Generate dates
+         dates = pd.date_range(
+             start=f'{self.config.start_year}-01-01',
183
+ periods=n_days,
184
+ freq='D'
185
+ )
186
+
187
+ t = np.arange(n_days)
188
+
189
+ # Base components
190
+ trend = trend_strength * t
191
+
192
+ # Seasonal components
193
+ seasonal = 0
194
+ periods = [365, 30, 7] # yearly, monthly, weekly seasonality
195
+ for i, (period, amplitude) in enumerate(zip(periods, seasonal_amplitude)):
196
+ seasonal += amplitude * np.sin(2 * np.pi * t / period)
197
+ if i < len(seasonal_amplitude) - 1:
198
+ seasonal += 0.5 * amplitude * np.cos(4 * np.pi * t / period)
199
+
200
+ # Cyclical component (business cycles)
201
+ cycle = 20 * np.sin(2 * np.pi * t / (365*5)) # 5-year cycle
202
+
203
+ # Noise
204
+ noise = np.random.normal(0, noise_std, n_days)
205
+
206
+ # Generate target variable
207
+ raskhodvoda = 100 + trend + seasonal + cycle + noise
208
+
209
+ # Create DataFrame
210
+ self.data = pd.DataFrame(
211
+ index=dates,
212
+ data={'raskhodvoda': raskhodvoda}
213
+ )
214
+
215
+ # Generate exogenous variables
216
+ if include_exogenous:
217
+ # Temperature with seasonality
218
+ tavg = 10 + 8 * np.sin(2 * np.pi * t / 365) + np.random.normal(0, 3, n_days)
219
+ tmin = tavg - 5 + np.random.normal(0, 2, n_days)
220
+ tmax = tavg + 5 + np.random.normal(0, 2, n_days)
221
+
222
+ # Water level with trend and seasonality
223
+ urovenvoda = 200 + 0.5 * t + 20 * np.sin(2 * np.pi * t / 365) + np.random.normal(0, 5, n_days)
224
+
225
+ # Add to DataFrame
226
+ self.data['tavg'] = tavg
227
+ self.data['tmin'] = tmin
228
+ self.data['tmax'] = tmax
229
+ self.data['urovenvoda'] = urovenvoda
230
+
231
+ # Add noisy lags
232
+ for lag in [1, 7, 30]:
233
+ self.data[f'tavg_lag_{lag}'] = self.data['tavg'].shift(lag) + np.random.normal(0, 1, n_days)
234
+
235
+ # Add missing values and outliers for testing
236
+ if n_days > 100:
237
+ # Missing values (5% of data)
238
+ mask_missing = np.random.random(n_days) < 0.05
239
+ self.data.loc[mask_missing, 'tavg'] = np.nan
240
+
241
+ # Outliers (1% of data)
242
+ mask_outliers = np.random.random(n_days) < 0.01
243
+ self.data.loc[mask_outliers, 'raskhodvoda'] *= 2
244
+
245
+ # Save metadata
246
+ self.metadata.update({
247
+ 'is_synthetic': True,
248
+ 'synthetic_params': {
249
+ 'n_days': n_days,
250
+ 'trend_strength': trend_strength,
251
+ 'seasonal_amplitude': seasonal_amplitude,
252
+ 'noise_std': noise_std,
253
+ 'include_exogenous': include_exogenous,
254
+ 'random_state': random_state
255
+ }
256
+ })
257
+
258
+ logger.info(f"✓ Created {len(self.data)} synthetic records")
259
+ logger.info(f" Columns: {list(self.data.columns)}")
260
+
261
+ return self.data
262
+
263
+ def _calculate_data_hash(self) -> str:
264
+ """Calculate data hash for tracking changes"""
265
+ if self.data is None:
266
+ return None
267
+
268
+ # Hash the first 1000 rows as a lightweight fingerprint of the data
269
+ sample = self.data.head(1000).to_string().encode()
270
+ return hashlib.md5(sample).hexdigest()
271
+
272
+ def _analyse_data_types(self) -> None:
273
+ """Analyse data types in DataFrame"""
274
+ if self.data is None:
275
+ return
276
+
277
+ for col in self.data.columns:
278
+ dtype = str(self.data[col].dtype)
279
+
280
+ if 'datetime' in dtype:
281
+ self.data_types[col] = DataType.TEMPORAL.value
282
+ elif 'int' in dtype or 'float' in dtype:
283
+ self.data_types[col] = DataType.NUMERIC.value
284
+ elif 'object' in dtype or 'category' in dtype:
285
+ # Check if categorical
286
+ unique_ratio = self.data[col].nunique() / len(self.data)
287
+ if unique_ratio < 0.1: # Less than 10% unique values
288
+ self.data_types[col] = DataType.CATEGORICAL.value
289
+ else:
290
+ self.data_types[col] = DataType.TEXT.value
291
+ else:
292
+ self.data_types[col] = 'unknown'
293
+
294
+ def _save_metadata(self) -> None:
295
+ """Save data metadata"""
296
+ if self.data is None:
297
+ return
298
+
299
+ # Basic metadata
300
+ self.metadata.update({
301
+ 'original_shape': list(self.original_shape) if self.original_shape else [],
302
+ 'current_shape': list(self.data.shape),
303
+ 'columns': list(self.data.columns),
304
+ 'data_types': self.data_types,
305
+ 'date_range': {
306
+ 'min': self.data.index.min().strftime('%Y-%m-%d') if pd.notnull(self.data.index.min()) else None,
307
+ 'max': self.data.index.max().strftime('%Y-%m-%d') if pd.notnull(self.data.index.max()) else None
308
+ },
309
+ 'data_hash': self.data_hash,
310
+ 'loading_time': self.loading_time
311
+ })
312
+
313
+ # Statistics for numeric columns
314
+ numeric_cols = self.data.select_dtypes(include=[np.number]).columns
315
+ if len(numeric_cols) > 0:
316
+ stats = self.data[numeric_cols].describe().to_dict()
317
+ # Add additional statistics
318
+ for col in numeric_cols:
319
+ stats[col]['skewness'] = float(self.data[col].skew())
320
+ stats[col]['kurtosis'] = float(self.data[col].kurtosis())
321
+ stats[col]['cv'] = float(self.data[col].std() / self.data[col].mean()) if self.data[col].mean() != 0 else np.nan
322
+
323
+ self.metadata['numeric_statistics'] = stats
324
+
325
+ # Missing values information
326
+ missing_info = {
327
+ 'total_missing': int(self.data.isnull().sum().sum()),
328
+ 'missing_by_column': self.data.isnull().sum().to_dict(),
329
+ 'missing_percentage': (self.data.isnull().sum() / len(self.data) * 100).to_dict(),
330
+ 'rows_with_missing': int(self.data.isnull().any(axis=1).sum()),
331
+ 'columns_with_missing': self.data.columns[self.data.isnull().any()].tolist()
332
+ }
333
+ self.metadata['missing_info'] = missing_info
334
+
335
+ def get_data_info(self) -> Dict:
336
+ """Get information about data"""
337
+ if self.data is None:
338
+ return {}
339
+
340
+ info = {
341
+ 'shape': list(self.data.shape),
342
+ 'columns': list(self.data.columns),
343
+ 'data_types': self.data_types,
344
+ 'date_range': {
345
+ 'min': self.data.index.min().strftime('%Y-%m-%d') if pd.notnull(self.data.index.min()) else None,
346
+ 'max': self.data.index.max().strftime('%Y-%m-%d') if pd.notnull(self.data.index.max()) else None
347
+ },
348
+ 'target_column': self.config.target_column,
349
+ 'numeric_columns': self.data.select_dtypes(include=[np.number]).columns.tolist(),
350
+ 'categorical_columns': [col for col, dtype in self.data_types.items()
351
+ if dtype == DataType.CATEGORICAL.value],
352
+ 'missing_info': self.metadata.get('missing_info', {})
353
+ }
354
+
355
+ return info
356
+
357
+ def save_raw_data_info(self) -> None:
358
+ """Save raw data information"""
359
+ if self.data is None:
360
+ return
361
+
362
+ info_path = f'{self.config.results_dir}/reports/raw_data_info.json'
363
+
364
+ # Custom JSON encoder for handling numpy types
365
+ class NumpyEncoder(json.JSONEncoder):
366
+ def default(self, obj):
367
+ if isinstance(obj, (np.integer, np.floating)):
368
+ if np.isnan(obj):
369
+ return None
370
+ return float(obj)
371
+ elif isinstance(obj, np.bool_):
372
+ return bool(obj)
373
+ elif isinstance(obj, np.ndarray):
374
+ return obj.tolist()
375
+ elif isinstance(obj, pd.Timestamp):
376
+ return obj.strftime('%Y-%m-%d %H:%M:%S')
377
+ elif isinstance(obj, pd.Period):
378
+ return str(obj)
379
+ return super().default(obj)
380
+
381
+ with open(info_path, 'w', encoding='utf-8') as f:
382
+ json.dump(self.metadata, f, indent=4, ensure_ascii=False, cls=NumpyEncoder)
383
+
384
+ logger.info(f"✓ Raw data information saved: {info_path}")
385
+
386
+ def resample_data(
387
+ self,
388
+ freq: Optional[str] = None,
389
+ method: str = 'mean'
390
+ ) -> pd.DataFrame:
391
+ """
392
+ Resample time series data
393
+
394
+ Parameters:
395
+ -----------
396
+ freq : str, optional
397
+ New frequency (e.g., 'D', 'W', 'M')
398
+ method : str
399
+ Aggregation method: 'mean', 'sum', 'last', 'first'
400
+
401
+ Returns:
402
+ --------
403
+ pd.DataFrame
404
+ Resampled data
405
+ """
406
+ if self.data is None:
407
+ logger.warning("Data not loaded")
408
+ return None
409
+
410
+ freq = freq or self.config.freq
411
+
412
+ # Check if index is datetime
413
+ if not isinstance(self.data.index, pd.DatetimeIndex):
414
+ logger.error("Data index is not DatetimeIndex")
415
+ return self.data
416
+
417
+ # Aggregation methods
418
+ agg_methods = {
419
+ 'mean': 'mean',
420
+ 'sum': 'sum',
421
+ 'last': lambda x: x.iloc[-1],
422
+ 'first': lambda x: x.iloc[0],
423
+ 'min': 'min',
424
+ 'max': 'max',
425
+ 'median': 'median'
426
+ }
427
+
428
+ if method not in agg_methods:
429
+ logger.warning(f"Method {method} not supported, using mean")
430
+ method = 'mean'
431
+
432
+ # Resampling
433
+ try:
434
+ if method == 'last':
435
+ resampled_data = self.data.resample(freq).last()
436
+ elif method == 'first':
437
+ resampled_data = self.data.resample(freq).first()
438
+ else:
439
+ resampled_data = self.data.resample(freq).agg(agg_methods[method])
440
+
441
+ logger.info(f"Data resampled to frequency {freq}, method {method}")
442
+ logger.info(f"Size before: {len(self.data)}, after: {len(resampled_data)}")
443
+
444
+ self.data = resampled_data
445
+ return self.data
446
+
447
+ except Exception as e:
448
+ logger.error(f"Error during resampling: {e}")
449
+ return self.data
450
+
451
+ def detect_frequency(self) -> str:
452
+ """
453
+ Automatically detect data frequency
454
+
455
+ Returns:
456
+ --------
457
+ str
458
+ Detected data frequency
459
+ """
460
+ if self.data is None or len(self.data) < 2:
461
+ return 'unknown'
462
+
463
+ if not isinstance(self.data.index, pd.DatetimeIndex):
464
+ return 'irregular'
465
+
466
+ # Calculate differences between timestamps
467
+ diffs = pd.Series(self.data.index).diff().dropna()
468
+
469
+ if len(diffs) == 0:
470
+ return 'unknown'
471
+
472
+ # Most frequent difference
473
+ mode_diff = diffs.mode().iloc[0] if not diffs.mode().empty else diffs.iloc[0]
474
+
475
+ # Determine frequency (inclusive bounds, so an exact 1-day gap maps to 'D', not 'W')
476
+ if mode_diff <= pd.Timedelta('1 hour'):
477
+ return 'H' # Hourly
478
+ elif mode_diff <= pd.Timedelta('1 day'):
479
+ return 'D' # Daily
480
+ elif mode_diff <= pd.Timedelta('7 days'):
481
+ return 'W' # Weekly
482
+ elif mode_diff <= pd.Timedelta('31 days'):
483
+ return 'M' # Monthly
484
+ elif mode_diff <= pd.Timedelta('92 days'):
485
+ return 'Q' # Quarterly
486
+ else:
487
+ return 'Y' # Yearly
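
Taken together, `DataLoader` covers CSV ingestion, synthetic generation (trend plus sinusoidal seasonal terms, a 5-year cycle, and Gaussian noise), frequency detection, and resampling. A minimal usage sketch, assuming `Config()` can be constructed with workable defaults for `start_year`, `end_year`, `target_column`, and `freq`:

```python
from config.config import Config
from data_loader.data_loader import DataLoader

config = Config()  # assumption: Config provides usable defaults
loader = DataLoader(config)

# Either load a real CSV (a 'date' column is parsed and becomes the index) ...
# df = loader.load_from_csv("data/flow.csv")  # hypothetical path
# ... or generate the built-in synthetic series for a quick smoke test.
df = loader.create_synthetic_data(n_days=365 * 3, include_exogenous=True)

print(loader.detect_frequency())                 # 'D' for the daily synthetic data
weekly = loader.resample_data(freq="W", method="mean")
print(loader.get_data_info()["shape"])
```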
decomposition/__init__.py ADDED
File without changes
decomposition/decomposer.py ADDED
@@ -0,0 +1,690 @@
1
+ # ============================================
2
+ # CLASS 7: TIME SERIES DECOMPOSITION
3
+ # ============================================
4
+ import traceback
5
+ from typing import Dict, Optional
6
+ import logging
+ logger = logging.getLogger(__name__)
7
+
8
+ from config.config import Config
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import statsmodels.api as sm
14
+ from scipy import stats
15
+ from statsmodels.tsa.seasonal import seasonal_decompose, STL
16
+ from statsmodels.tsa.stattools import acf
17
+ from statsmodels.graphics.tsaplots import plot_acf
18
+ from statsmodels.stats.diagnostic import acorr_ljungbox
21
+
22
+
23
+ class TimeSeriesDecomposer:
24
+ """Class for time series decomposition"""
25
+
26
+ def __init__(self, config: Config):
27
+ """
28
+ Initialise decomposer
29
+
30
+ Parameters:
31
+ -----------
32
+ config : Config
33
+ Experiment configuration
34
+ """
35
+ self.config = config
36
+ self.decomposition_results = {}
37
+ self.decomposition_models = {}
38
+ self.seasonal_periods = {}
39
+
40
+ def decompose(
41
+ self,
42
+ data: pd.DataFrame,
43
+ target_col: Optional[str] = None,
44
+ method: str = 'stl',
45
+ period: Optional[int] = None,
46
+ **kwargs
47
+ ) -> Dict:
48
+ """
49
+ Decompose time series into components
50
+
51
+ Parameters:
52
+ -----------
53
+ data : pd.DataFrame
54
+ Input data
55
+ target_col : str, optional
56
+ Target variable. If None, uses configuration value.
57
+ method : str
58
+ Decomposition model: 'stl', 'seasonal_decompose', 'mstl', 'naive'
59
+ period : int, optional
60
+ Seasonality period. If None, uses configuration value.
61
+ **kwargs : dict
62
+ Additional parameters for method
63
+
64
+ Returns:
65
+ --------
66
+ Dict
67
+ Decomposition results
68
+ """
69
+ logger.info("\n" + "="*80)
70
+ logger.info("TIME SERIES DECOMPOSITION")
71
+ logger.info("="*80)
72
+
73
+ target_col = target_col or self.config.target_column
74
+ period = period or self.config.seasonal_period
75
+
76
+ if target_col not in data.columns:
77
+ logger.error(f"Target variable '{target_col}' not found")
78
+ return {}
79
+
80
+ # Set date as index if not set
81
+ if not isinstance(data.index, pd.DatetimeIndex):
82
+ if 'date' in data.columns:
83
+ data = data.set_index('date')
84
+ else:
85
+ logger.error("DatetimeIndex required for decomposition")
86
+ return {}
87
+
88
+ series = data[target_col]
89
+
90
+ # Automatic seasonality period detection
91
+ if period is None or period == 'auto':
92
+ period = self._detect_seasonal_period(series)
93
+ logger.info(f"Automatically detected seasonality period: {period}")
94
+
95
+ try:
96
+ decomposition_result = None
97
+
98
+ if method == 'stl':
99
+ decomposition_result = self._stl_decomposition(series, period, **kwargs)
100
+ elif method == 'seasonal_decompose':
101
+ decomposition_result = self._seasonal_decompose(series, period, **kwargs)
102
+ elif method == 'mstl':
103
+ decomposition_result = self._mstl_decomposition(series, **kwargs)
104
+ elif method == 'naive':
105
+ decomposition_result = self._naive_decomposition(series, period, **kwargs)
106
+ else:
107
+ logger.warning(f"Method {method} not supported, using STL")
108
+ decomposition_result = self._stl_decomposition(series, period, **kwargs)
109
+
110
+ if decomposition_result is None:
111
+ logger.error("Decomposition failed")
112
+ return {}
113
+
114
+ # Analyse residuals
115
+ residuals_info = self._analyse_residuals(decomposition_result.get('residual', None))
116
+
117
+ # Analyse seasonality
118
+ seasonal_info = self._analyse_seasonality(
119
+ decomposition_result.get('seasonal', None),
120
+ period
121
+ )
122
+
123
+ # Save results
124
+ self.decomposition_results[target_col] = {
125
+ 'method': method,
126
+ 'period': period,
127
+ 'residuals_analysis': residuals_info,
128
+ 'seasonality_analysis': seasonal_info,
129
+ 'components_present': list(decomposition_result.keys()),
130
+ 'decomposition_stats': {
131
+ 'trend_strength': self._calculate_trend_strength(
132
+ decomposition_result.get('trend', None),
133
+ decomposition_result.get('residual', None)
134
+ ),
135
+ 'seasonal_strength': self._calculate_seasonal_strength(
136
+ decomposition_result.get('seasonal', None),
137
+ decomposition_result.get('residual', None)
138
+ )
139
+ }
140
+ }
141
+
142
+ # Visualisation
143
+ if self.config.save_plots:
144
+ self._plot_decomposition(data, target_col, decomposition_result, method, period)
145
+
146
+ # Additional visualisation
147
+ if residuals_info:
148
+ self._plot_residuals_analysis(decomposition_result.get('residual', None), target_col)
149
+
150
+ return self.decomposition_results[target_col]
151
+
152
+ except Exception as e:
153
+ logger.error(f"Error during decomposition: {e}")
154
+ logger.error(traceback.format_exc())
155
+ return {}
156
+
157
+ def _detect_seasonal_period(self, series: pd.Series) -> int:
158
+ """Automatic seasonality period detection"""
159
+ if len(series) < 100:
160
+ return self.config.seasonal_period
161
+
162
+ try:
163
+ # Use autocorrelation to determine period
164
+ acf_values = acf(series.dropna(), nlags=min(500, len(series)//2))
165
+
166
+ # Find peaks in autocorrelation
167
+ peaks = []
168
+ for i in range(1, len(acf_values)-1):
169
+ if acf_values[i] > acf_values[i-1] and acf_values[i] > acf_values[i+1]:
170
+ if acf_values[i] > 0.3: # Significance threshold
171
+ peaks.append(i)
172
+
173
+ if peaks:
174
+ # Take the first detected peak (smallest lag) as the candidate period
175
+ dominant_period = peaks[0]
176
+
177
+ # Check for multiple periods
178
+ for period in [7, 30, 90, 365]:
179
+ if abs(dominant_period - period) <= 2:
180
+ return period
181
+
182
+ return dominant_period
183
+
184
+ return self.config.seasonal_period
185
+
186
+ except Exception:
187
+ return self.config.seasonal_period
188
+
189
+ def _stl_decomposition(
190
+ self,
191
+ series: pd.Series,
192
+ period: int,
193
+ **kwargs
194
+ ) -> Optional[Dict]:
195
+ """STL decomposition"""
196
+ try:
197
+ if len(series) < 2 * period:
198
+ logger.warning(f"Insufficient data for STL decomposition with period {period}")
199
+ return self._seasonal_decompose(series, period, **kwargs)
200
+
201
+ # STL decomposition
202
+ stl = STL(
203
+ series,
204
+ period=period,
205
+ seasonal=kwargs.get('seasonal', 7),
206
+ trend=kwargs.get('trend', None),
207
+ robust=kwargs.get('robust', True),
208
+ seasonal_deg=kwargs.get('seasonal_deg', 1),
209
+ trend_deg=kwargs.get('trend_deg', 1),
210
+ low_pass_deg=kwargs.get('low_pass_deg', 1)
211
+ )
212
+
213
+ result = stl.fit()
214
+
215
+ return {
216
+ 'trend': result.trend,
217
+ 'seasonal': result.seasonal,
218
+ 'residual': result.resid,
219
+ 'observed': series
220
+ }
221
+
222
+ except Exception as e:
223
+ logger.warning(f"STL decomposition failed: {e}")
224
+ return self._seasonal_decompose(series, period, **kwargs)
225
+
226
+ def _seasonal_decompose(
227
+ self,
228
+ series: pd.Series,
229
+ period: int,
230
+ **kwargs
231
+ ) -> Optional[Dict]:
232
+ """Seasonal decomposition from statsmodels"""
233
+ try:
234
+ model = kwargs.get('model', 'additive')
235
+
236
+ if len(series) < 2 * period:
237
+ # Reduce period if insufficient data
238
+ period = max(7, len(series) // 4)
239
+
240
+ decomposition = seasonal_decompose(
241
+ series,
242
+ model=model,
243
+ period=period,
244
+ extrapolate_trend=kwargs.get('extrapolate_trend', 'freq'),
245
+ two_sided=kwargs.get('two_sided', True)
246
+ )
247
+
248
+ return {
249
+ 'trend': decomposition.trend,
250
+ 'seasonal': decomposition.seasonal,
251
+ 'residual': decomposition.resid,
252
+ 'observed': series
253
+ }
254
+
255
+ except Exception as e:
256
+ logger.warning(f"Seasonal decompose failed: {e}")
257
+ return self._naive_decomposition(series, period, **kwargs)
258
+
259
+ def _mstl_decomposition(
260
+ self,
261
+ series: pd.Series,
262
+ **kwargs
263
+ ) -> Optional[Dict]:
264
+ """Multi-seasonal decomposition (simplified)"""
265
+ try:
266
+ # Simplified MSTL version
267
+ periods = kwargs.get('periods', [7, 365])
268
+
269
+ result = {
270
+ 'observed': series,
271
+ 'trend': None,
272
+ 'seasonal': pd.Series(0, index=series.index),
273
+ 'residual': series.copy()
274
+ }
275
+
276
+ # Sequentially remove seasonal components
277
+ for period in periods:
278
+ if len(series) >= 2 * period:
279
+ try:
280
+ decomp = seasonal_decompose(
281
+ result['residual'],
282
+ model='additive',
283
+ period=period,
284
+ extrapolate_trend='freq'
285
+ )
286
+
287
+ if result['trend'] is None:
288
+ result['trend'] = decomp.trend
289
+
290
+ result['seasonal'] = result['seasonal'] + decomp.seasonal
291
+ result['residual'] = decomp.resid
292
+ except Exception:
293
+ continue
294
+
295
+ if result['trend'] is None:
296
+ result['trend'] = series.rolling(window=min(365, len(series)//4), center=True).mean()
297
+
298
+ return result
299
+
300
+ except Exception as e:
301
+ logger.warning(f"MSTL decomposition failed: {e}")
302
+ return self._seasonal_decompose(series, 365, **kwargs)
303
+
304
+ def _naive_decomposition(
305
+ self,
306
+ series: pd.Series,
307
+ period: int,
308
+ **kwargs
309
+ ) -> Optional[Dict]:
310
+ """Naive decomposition"""
311
+ try:
312
+ # Simple decomposition using moving averages
313
+ trend = series.rolling(
314
+ window=min(period, len(series)//4),
315
+ center=True,
316
+ min_periods=1
317
+ ).mean()
318
+
319
+ # Seasonal component
320
+ if period > 1:
321
+ # Average by seasons
322
+ seasonal = series.groupby(series.index.dayofyear if period == 365 else
323
+ series.index.dayofweek if period == 7 else
324
+ series.index.month).transform('mean')
325
+ seasonal = seasonal - seasonal.mean()
326
+ else:
327
+ seasonal = pd.Series(0, index=series.index)
328
+
329
+ residual = series - trend - seasonal
330
+
331
+ return {
332
+ 'trend': trend,
333
+ 'seasonal': seasonal,
334
+ 'residual': residual,
335
+ 'observed': series
336
+ }
337
+
338
+ except Exception as e:
339
+ logger.error(f"Naive decomposition failed: {e}")
340
+ return None
341
+
342
+ def _analyse_residuals(self, residuals) -> Dict:
343
+ """Analyse decomposition residuals"""
344
+ if residuals is None:
345
+ return {}
346
+
347
+ residuals_clean = residuals.dropna()
348
+
349
+ if len(residuals_clean) == 0:
350
+ return {}
351
+
352
+ stats_info = {
353
+ 'mean': float(residuals_clean.mean()),
354
+ 'std': float(residuals_clean.std()),
355
+ 'skewness': float(residuals_clean.skew()),
356
+ 'kurtosis': float(residuals_clean.kurtosis()),
357
+ 'min': float(residuals_clean.min()),
358
+ 'max': float(residuals_clean.max()),
359
+ 'mad': float((residuals_clean - residuals_clean.mean()).abs().mean()),
360
+ 'normality_tests': {},
361
+ 'autocorrelation_tests': {}
362
+ }
363
+
364
+ # Normality test
365
+ if len(residuals_clean) > 3:
366
+ try:
367
+ # Shapiro-Wilk test (capped at 5000 samples, the test's practical upper limit)
368
+ shapiro_stat, shapiro_p = stats.shapiro(residuals_clean.iloc[:5000])
369
+ stats_info['normality_tests']['shapiro_wilk'] = {
370
+ 'statistic': float(shapiro_stat),
371
+ 'pvalue': float(shapiro_p),
372
+ 'is_normal': shapiro_p > 0.05
373
+ }
374
+
375
+ # Anderson-Darling test
376
+ anderson_result = stats.anderson(residuals_clean, dist='norm')
377
+ stats_info['normality_tests']['anderson_darling'] = {
378
+ 'statistic': float(anderson_result.statistic),
379
+ 'critical_values': {str(level): float(value)
380
+ for level, value in zip(anderson_result.significance_level,
381
+ anderson_result.critical_values)},
382
+ 'is_normal': anderson_result.statistic < anderson_result.critical_values[2] # At 5% level
383
+ }
384
+ except Exception:
385
+ stats_info['normality_tests']['error'] = 'not enough data or calculation error'
386
+
387
+ # Autocorrelation test
388
+ try:
389
+ # Ljung-Box test
390
+ lb_test = acorr_ljungbox(residuals_clean, lags=[10, 20, 30], return_df=True)
391
+
392
+ autocorr_info = {}
393
+ for idx, row in lb_test.iterrows():
394
+ autocorr_info[f'lag_{int(row.name)}'] = {
395
+ 'statistic': float(row['lb_stat']),
396
+ 'pvalue': float(row['lb_pvalue']),
397
+ 'has_autocorrelation': row['lb_pvalue'] < 0.05
398
+ }
399
+
400
+ stats_info['autocorrelation_tests']['ljung_box'] = autocorr_info
401
+
402
+ # Durbin-Watson test
403
+ try:
404
+ dw_stat = sm.stats.stattools.durbin_watson(residuals_clean)
405
+ stats_info['autocorrelation_tests']['durbin_watson'] = {
406
+ 'statistic': float(dw_stat),
407
+ 'interpretation': 'no autocorrelation' if 1.5 < dw_stat < 2.5 else
408
+ 'positive autocorrelation' if dw_stat < 1.5 else
409
+ 'negative autocorrelation'
410
+ }
411
+ except Exception:
412
+ pass
413
+
414
+ except Exception:
415
+ stats_info['autocorrelation_tests']['error'] = 'calculation error'
416
+
417
+ # Heteroskedasticity test
418
+ try:
419
+ # ARCH test
420
+ from statsmodels.stats.diagnostic import het_arch
421
+ arch_test = het_arch(residuals_clean)
422
+ stats_info['heteroskedasticity_tests'] = {
423
+ 'arch': {
424
+ 'statistic': float(arch_test[0]),
425
+ 'pvalue': float(arch_test[1]),
426
+ 'is_homoskedastic': arch_test[1] > 0.05
427
+ }
428
+ }
429
+ except Exception:
430
+ pass
431
+
432
+ return stats_info
433
+
434
+ def _analyse_seasonality(self, seasonal_component, period: int) -> Dict:
435
+ """Analyse seasonal component"""
436
+ if seasonal_component is None:
437
+ return {}
438
+
439
+ seasonal_clean = seasonal_component.dropna()
440
+
441
+ if len(seasonal_clean) == 0:
442
+ return {}
443
+
444
+ analysis = {
445
+ 'period': period,
446
+ 'amplitude': float(seasonal_clean.max() - seasonal_clean.min()),
447
+ 'mean_amplitude': float(seasonal_clean.abs().mean()),
448
+ 'seasonal_strength': float(seasonal_clean.std()),
449
+ 'periodicity_check': {}
450
+ }
451
+
452
+ # Check periodicity via autocorrelation
453
+ if len(seasonal_clean) > period * 2:
454
+ try:
455
+ acf_values = acf(seasonal_clean, nlags=min(period * 3, len(seasonal_clean)//2))
456
+
457
+ # Look for peaks at expected lags
458
+ expected_lags = [period, period*2]
459
+ peaks_found = []
460
+
461
+ for lag in expected_lags:
462
+ if lag < len(acf_values):
463
+ if acf_values[lag] > 0.5: # Strong autocorrelation at period
464
+ peaks_found.append({
465
+ 'lag': lag,
466
+ 'autocorrelation': float(acf_values[lag]),
467
+ 'is_significant': True
468
+ })
469
+
470
+ analysis['periodicity_check']['autocorrelation_peaks'] = peaks_found
471
+ analysis['periodicity_check']['is_periodic'] = len(peaks_found) > 0
472
+ except Exception:
473
+ pass
474
+
475
+ # Seasonality pattern analysis
476
+ if isinstance(seasonal_clean.index, pd.DatetimeIndex):
477
+ try:
478
+ # Group by months/week days
479
+ if period == 12 or period == 365:
480
+ # Monthly seasonality
481
+ monthly_seasonal = seasonal_clean.groupby(seasonal_clean.index.month).mean()
482
+ analysis['monthly_pattern'] = monthly_seasonal.to_dict()
483
+
484
+ if period == 7 or period == 365:
485
+ # Daily seasonality
486
+ daily_seasonal = seasonal_clean.groupby(seasonal_clean.index.dayofweek).mean()
487
+ analysis['daily_pattern'] = daily_seasonal.to_dict()
488
+ except Exception:
489
+ pass
490
+
491
+ return analysis
492
+
493
+ def _calculate_trend_strength(self, trend, residual) -> float:
494
+ """Calculate trend strength"""
495
+ if trend is None or residual is None:
496
+ return 0.0
497
+
498
+ trend_clean = trend.dropna()
499
+ residual_clean = residual.dropna()
500
+
501
+ if len(trend_clean) == 0 or len(residual_clean) == 0:
502
+ return 0.0
503
+
504
+ # Trend strength = 1 - Var(residual) / Var(trend + residual)
505
+ try:
506
+ var_total = np.var(trend_clean + residual_clean)
507
+ if var_total > 0:
508
+ trend_strength = 1 - np.var(residual_clean) / var_total
509
+ return max(0.0, min(1.0, float(trend_strength)))
510
+ except Exception:
511
+ pass
512
+
513
+ return 0.0
514
+
515
+ def _calculate_seasonal_strength(self, seasonal, residual) -> float:
516
+ """Calculate seasonality strength"""
517
+ if seasonal is None or residual is None:
518
+ return 0.0
519
+
520
+ seasonal_clean = seasonal.dropna()
521
+ residual_clean = residual.dropna()
522
+
523
+ if len(seasonal_clean) == 0 or len(residual_clean) == 0:
524
+ return 0.0
525
+
526
+ # Seasonality strength = 1 - Var(residual) / Var(seasonal + residual)
527
+ try:
528
+ var_total = np.var(seasonal_clean + residual_clean)
529
+ if var_total > 0:
530
+ seasonal_strength = 1 - np.var(residual_clean) / var_total
531
+ return max(0.0, min(1.0, float(seasonal_strength)))
532
+ except Exception:
533
+ pass
534
+
535
+ return 0.0
536
+
537
+ def _plot_decomposition(
538
+ self,
539
+ data: pd.DataFrame,
540
+ target_col: str,
541
+ decomposition: Dict,
542
+ method: str,
543
+ period: int
544
+ ) -> None:
545
+ """Visualise decomposition"""
546
+ fig, axes = plt.subplots(4, 1, figsize=(14, 12))
547
+
548
+ # Original series
549
+ axes[0].plot(decomposition.get('observed', pd.Series()))
550
+ axes[0].set_ylabel('Observed')
551
+ axes[0].set_title(f'Time Series Decomposition: {target_col} ({method}, period={period})')
552
+ axes[0].grid(True, alpha=0.3)
553
+
554
+ # Trend
555
+ if 'trend' in decomposition and decomposition['trend'] is not None:
556
+ axes[1].plot(decomposition['trend'])
557
+ axes[1].set_ylabel('Trend')
558
+ axes[1].grid(True, alpha=0.3)
559
+
560
+ # Seasonality
561
+ if 'seasonal' in decomposition and decomposition['seasonal'] is not None:
562
+ axes[2].plot(decomposition['seasonal'])
563
+ axes[2].set_ylabel('Seasonality')
564
+ axes[2].grid(True, alpha=0.3)
565
+
566
+ # Residuals
567
+ if 'residual' in decomposition and decomposition['residual'] is not None:
568
+ axes[3].plot(decomposition['residual'])
569
+ axes[3].set_ylabel('Residuals')
570
+ axes[3].set_xlabel('Date')
571
+ axes[3].grid(True, alpha=0.3)
572
+
573
+ plt.tight_layout()
574
+ plt.savefig(
575
+ f'{self.config.results_dir}/plots/decomposition_{target_col}.png',
576
+ dpi=300,
577
+ bbox_inches='tight'
578
+ )
579
+ plt.close(fig)
580
+
581
+ # Additional plots
582
+ self._plot_decomposition_components(data, target_col, decomposition)
583
+
584
+ def _plot_decomposition_components(
585
+ self,
586
+ data: pd.DataFrame,
587
+ target_col: str,
588
+ decomposition: Dict
589
+ ) -> None:
590
+ """Visualise decomposition components"""
591
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
592
+
593
+ # 1. Sum of components vs original series
594
+ if all(k in decomposition for k in ['trend', 'seasonal', 'residual']):
595
+ reconstructed = decomposition['trend'] + decomposition['seasonal'] + decomposition['residual']
596
+ axes[0, 0].plot(decomposition['observed'], alpha=0.7, label='Original')
597
+ axes[0, 0].plot(reconstructed, alpha=0.7, label='Reconstructed')
598
+ axes[0, 0].set_title('Original vs Reconstructed Series')
599
+ axes[0, 0].set_xlabel('Date')
600
+ axes[0, 0].set_ylabel(target_col)
601
+ axes[0, 0].legend()
602
+ axes[0, 0].grid(True, alpha=0.3)
603
+
604
+ # 2. Residuals distribution
605
+ if 'residual' in decomposition and decomposition['residual'] is not None:
606
+ residuals = decomposition['residual'].dropna()
607
+ axes[0, 1].hist(residuals, bins=30, edgecolor='black', alpha=0.7, density=True)
608
+
609
+ # Normal distribution for comparison
610
+ xmin, xmax = axes[0, 1].get_xlim()
611
+ x = np.linspace(xmin, xmax, 100)
612
+ p = stats.norm.pdf(x, residuals.mean(), residuals.std())
613
+ axes[0, 1].plot(x, p, 'k', linewidth=2, label='Normal distribution')
614
+
615
+ axes[0, 1].set_title('Residuals Distribution')
616
+ axes[0, 1].set_xlabel('Residuals')
617
+ axes[0, 1].set_ylabel('Density')
618
+ axes[0, 1].legend()
619
+ axes[0, 1].grid(True, alpha=0.3)
620
+
621
+ # 3. ACF of residuals
622
+ if 'residual' in decomposition and decomposition['residual'] is not None:
623
+ plot_acf(decomposition['residual'].dropna(), lags=50, ax=axes[1, 0], alpha=0.05)
624
+ axes[1, 0].set_title('Residuals ACF')
625
+ axes[1, 0].set_xlabel('Lag')
626
+ axes[1, 0].set_ylabel('Autocorrelation')
627
+ axes[1, 0].grid(True, alpha=0.3)
628
+
629
+ # 4. Seasonal pattern
630
+ if 'seasonal' in decomposition and decomposition['seasonal'] is not None:
631
+ seasonal = decomposition['seasonal']
632
+ if isinstance(seasonal.index, pd.DatetimeIndex):
633
+ # Group by months
634
+ try:
635
+ monthly_seasonal = seasonal.groupby(seasonal.index.month).mean()
636
+ axes[1, 1].bar(monthly_seasonal.index, monthly_seasonal.values)
637
+ axes[1, 1].set_title('Average Seasonal Pattern by Month')
638
+ axes[1, 1].set_xlabel('Month')
639
+ axes[1, 1].set_ylabel('Seasonality')
640
+ axes[1, 1].set_xticks(range(1, 13))
641
+ axes[1, 1].grid(True, alpha=0.3)
642
+ except Exception:
643
+ axes[1, 1].plot(seasonal.index, seasonal.values)
644
+ axes[1, 1].set_title('Seasonal Component')
645
+ axes[1, 1].grid(True, alpha=0.3)
646
+
647
+ plt.tight_layout()
648
+ plt.savefig(
649
+ f'{self.config.results_dir}/plots/decomposition_components_{target_col}.png',
650
+ dpi=300,
651
+ bbox_inches='tight'
652
+ )
653
+ plt.close(fig)
654
+
655
+ def _plot_residuals_analysis(self, residuals, target_col: str) -> None:
656
+ """Visualise residual analysis"""
657
+ if residuals is None:
658
+ return
659
+
660
+ residuals_clean = residuals.dropna()
661
+
662
+ if len(residuals_clean) == 0:
663
+ return
664
+
665
+ fig, axes = plt.subplots(1, 2, figsize=(12, 4))
666
+
667
+ # Q-Q plot
668
+ stats.probplot(residuals_clean, dist="norm", plot=axes[0])
669
+ axes[0].set_title('Residuals Q-Q plot')
670
+ axes[0].grid(True, alpha=0.3)
671
+
672
+ # Residuals over time
673
+ axes[1].plot(residuals_clean.index, residuals_clean.values, linewidth=0.5)
674
+ axes[1].axhline(y=0, color='r', linestyle='-', alpha=0.3)
675
+ axes[1].set_title('Residuals Over Time')
676
+ axes[1].set_xlabel('Date')
677
+ axes[1].set_ylabel('Residuals')
678
+ axes[1].grid(True, alpha=0.3)
679
+
680
+ plt.tight_layout()
681
+ plt.savefig(
682
+ f'{self.config.results_dir}/plots/residuals_analysis_{target_col}.png',
683
+ dpi=300,
684
+ bbox_inches='tight'
685
+ )
686
+ plt.close(fig)
687
+
688
+ def get_report(self) -> Dict:
689
+ """Get decomposition report"""
690
+ return self.decomposition_results
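
The strength scores stored in `decomposition_stats` follow the usual variance-ratio definitions, `F_trend = max(0, 1 - Var(R) / Var(T + R))` and `F_seasonal = max(0, 1 - Var(R) / Var(S + R))`, where `T`, `S`, and `R` are the trend, seasonal, and residual components. A minimal usage sketch, assuming `Config` exposes a `save_plots` flag (set off here to skip figure output) and `df` carries a DatetimeIndex plus the target column:

```python
from config.config import Config
from decomposition.decomposer import TimeSeriesDecomposer

config = Config()          # assumption: workable defaults
config.save_plots = False  # skip figure files in this sketch
decomposer = TimeSeriesDecomposer(config)

result = decomposer.decompose(df, target_col="raskhodvoda", method="stl", period=365)

strengths = result["decomposition_stats"]
print(f"trend strength:    {strengths['trend_strength']:.2f}")     # 1 - Var(R)/Var(T+R)
print(f"seasonal strength: {strengths['seasonal_strength']:.2f}")  # 1 - Var(R)/Var(S+R)
```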
feature_selection/__init__.py ADDED
File without changes
feature_selection/feature_selector.py ADDED
@@ -0,0 +1,478 @@
1
+ # ============================================
2
+ # CLASS 11: FEATURE SELECTION
3
+ # ============================================
4
+ from typing import Dict, List, Optional, Tuple
5
+ import logging
+ logger = logging.getLogger(__name__)
6
+ from config.config import Config
7
+
8
+ try:
9
+ import pandas as pd
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import seaborn as sns
13
+ from sklearn.ensemble import RandomForestRegressor
14
+ from sklearn.decomposition import PCA
15
+ from sklearn.preprocessing import StandardScaler
16
+ logger.debug("✅ All imports working!")
17
+ except ImportError as e:
18
+ logger.error(f"❌ Import error: {e}")
+ raise
19
+
20
+ from sklearn.feature_selection import mutual_info_regression
25
+
26
+ class FeatureSelector:
27
+ """Class for selecting the most important features"""
28
+
29
+ def __init__(self, config: Config):
30
+ """
31
+ Initialise feature selector
32
+
33
+ Parameters:
34
+ -----------
35
+ config : Config
36
+ Experiment configuration
37
+ """
38
+ self.config = config
39
+ self.selected_features = []
40
+ self.feature_importances = {}
41
+ self.selection_methods = {}
42
+ self.selector_objects = {}
43
+
44
+ def select(
45
+ self,
46
+ data: pd.DataFrame,
47
+ target_col: Optional[str] = None,
48
+ method: Optional[str] = None,
49
+ n_features: Optional[int] = None,
50
+ **kwargs
51
+ ) -> pd.DataFrame:
52
+ """
53
+ Select the most important features
54
+
55
+ Parameters:
56
+ -----------
57
+ data : pd.DataFrame
58
+ Input data
59
+ target_col : str, optional
60
+ Target variable. If None, uses configuration value.
61
+ method : str, optional
62
+ Selection method. If None, uses configuration value.
63
+ n_features : int, optional
64
+ Number of features to select. If None, uses configuration value.
65
+ **kwargs : dict
66
+ Additional parameters for method
67
+
68
+ Returns:
69
+ --------
70
+ pd.DataFrame
71
+ Data with selected features
72
+ """
73
+ logger.info("\n" + "="*80)
74
+ logger.info("FEATURE SELECTION")
75
+ logger.info("="*80)
76
+
77
+ target_col = target_col or self.config.target_column
78
+ method = method or self.config.feature_selection_method
79
+ n_features = n_features or self.config.max_features
80
+
81
+ if target_col not in data.columns:
82
+ logger.error(f"Target variable '{target_col}' not found")
83
+ return data
84
+
85
+ # Prepare data
86
+ X = data.drop(columns=[target_col]).select_dtypes(include=[np.number])
87
+ y = data[target_col]
88
+
89
+ # Remove missing values
90
+ mask = X.notna().all(axis=1) & y.notna()
91
+ X_clean = X[mask]
92
+ y_clean = y[mask]
93
+
94
+ if len(X_clean) < 10 or len(X_clean.columns) < 2:
95
+ logger.warning("Insufficient data for feature selection")
96
+ return data
97
+
98
+ logger.info(f"Selection method: {method}")
99
+ logger.info(f"Target number of features: {n_features}")
100
+ logger.info(f"Initial number of features: {len(X.columns)}")
101
+ logger.info(f"Data for selection: {len(X_clean)} records")
102
+
103
+ # Apply selection method
104
+ selected_features_list = []
105
+ feature_importance_dict = {}
106
+
107
+ if method == 'correlation':
108
+ selected_features_list, feature_importance_dict = self._correlation_selection(
109
+ X_clean, y_clean, n_features, **kwargs
110
+ )
111
+
112
+ elif method == 'mutual_info':
113
+ selected_features_list, feature_importance_dict = self._mutual_info_selection(
114
+ X_clean, y_clean, n_features, **kwargs
115
+ )
116
+
117
+ elif method == 'rf':
118
+ selected_features_list, feature_importance_dict = self._random_forest_selection(
119
+ X_clean, y_clean, n_features, **kwargs
120
+ )
121
+
122
+ elif method == 'pca':
123
+ selected_features_list, feature_importance_dict = self._pca_selection(
124
+ X_clean, y_clean, n_features, **kwargs
125
+ )
126
+
127
+ elif method == 'rfe':
128
+ selected_features_list, feature_importance_dict = self._rfe_selection(
129
+ X_clean, y_clean, n_features, **kwargs
130
+ )
131
+
132
+ elif method == 'lasso':
133
+ selected_features_list, feature_importance_dict = self._lasso_selection(
134
+ X_clean, y_clean, n_features, **kwargs
135
+ )
136
+
137
+ elif method == 'hybrid':
138
+ selected_features_list, feature_importance_dict = self._hybrid_selection(
139
+ X_clean, y_clean, n_features, **kwargs
140
+ )
141
+
142
+ else:
143
+ logger.warning(f"Method {method} not supported, using correlation")
144
+ selected_features_list, feature_importance_dict = self._correlation_selection(
145
+ X_clean, y_clean, n_features, **kwargs
146
+ )
147
+
148
+ # Save selected features
149
+ self.selected_features = selected_features_list
150
+ self.feature_importances = feature_importance_dict
151
+ self.selection_methods[method] = {
152
+ 'selected_features': selected_features_list,
153
+ 'n_features': len(selected_features_list),
154
+ 'feature_importances': feature_importance_dict
155
+ }
156
+
157
+ # Form final dataset
158
+ features_to_keep = selected_features_list + [target_col]
159
+ features_to_keep = [f for f in features_to_keep if f in data.columns]
160
+
161
+ data_selected = data[features_to_keep].copy()
162
+
163
+ logger.info(f"✓ Selected {len(selected_features_list)} features")
164
+ logger.info(f" Total features kept: {len(data_selected.columns)}")
165
+
166
+ # Visualisation
167
+ if self.config.save_plots and selected_features_list:
168
+ self._plot_feature_selection(
169
+ X_clean, y_clean, selected_features_list,
170
+ feature_importance_dict, method
171
+ )
172
+
173
+ return data_selected
174
+
175
+ def _correlation_selection(
176
+ self,
177
+ X: pd.DataFrame,
178
+ y: pd.Series,
179
+ n_features: int,
180
+ **kwargs
181
+ ) -> Tuple[List[str], Dict]:
182
+ """Feature selection based on correlation"""
183
+ # Calculate correlations with target variable
184
+ correlations = X.corrwith(y).abs().sort_values(ascending=False)
185
+
186
+ # Select top-n_features
187
+ selected_features = correlations.head(n_features).index.tolist()
188
+ feature_importance = correlations.to_dict()
189
+
190
+ return selected_features, feature_importance
191
+
192
+ def _mutual_info_selection(
193
+ self,
194
+ X: pd.DataFrame,
195
+ y: pd.Series,
196
+ n_features: int,
197
+ **kwargs
198
+ ) -> Tuple[List[str], Dict]:
199
+ """Feature selection based on mutual information"""
200
+ try:
201
+ mi_scores = mutual_info_regression(X, y, random_state=kwargs.get('random_state', 42))
202
+ mi_series = pd.Series(mi_scores, index=X.columns)
203
+ mi_series = mi_series.sort_values(ascending=False)
204
+
205
+ selected_features = mi_series.head(n_features).index.tolist()
206
+ feature_importance = mi_series.to_dict()
207
+
208
+ return selected_features, feature_importance
209
+
210
+ except Exception as e:
211
+ logger.warning(f"Mutual information selection failed: {e}, using correlation")
212
+ return self._correlation_selection(X, y, n_features, **kwargs)
213
+
214
+ def _random_forest_selection(
215
+ self,
216
+ X: pd.DataFrame,
217
+ y: pd.Series,
218
+ n_features: int,
219
+ **kwargs
220
+ ) -> Tuple[List[str], Dict]:
221
+ """Feature selection based on Random Forest"""
222
+ try:
223
+ rf = RandomForestRegressor(
224
+ n_estimators=kwargs.get('n_estimators', 100),
225
+ max_depth=kwargs.get('max_depth', None),
226
+ random_state=kwargs.get('random_state', 42),
227
+ n_jobs=self.config.n_jobs if self.config.use_multiprocessing else None
228
+ )
229
+
230
+ rf.fit(X, y)
231
+ importances = pd.Series(rf.feature_importances_, index=X.columns)
232
+ importances = importances.sort_values(ascending=False)
233
+
234
+ selected_features = importances.head(n_features).index.tolist()
235
+ feature_importance = importances.to_dict()
236
+
237
+ self.selector_objects['random_forest'] = rf
238
+
239
+ return selected_features, feature_importance
240
+
241
+ except Exception as e:
242
+ logger.warning(f"Random Forest selection failed: {e}, using correlation")
243
+ return self._correlation_selection(X, y, n_features, **kwargs)
244
+
245
+ def _pca_selection(
246
+ self,
247
+ X: pd.DataFrame,
248
+ y: pd.Series,
249
+ n_features: int,
250
+ **kwargs
251
+ ) -> Tuple[List[str], Dict]:
252
+ """Feature selection based on PCA"""
253
+ try:
254
+ # Standardise the data first (StandardScaler is imported at module level)
+ 
257
+ scaler = StandardScaler()
258
+ X_scaled = scaler.fit_transform(X)
259
+
260
+ # Apply PCA
261
+ pca = PCA(n_components=min(n_features, len(X.columns)))
262
+ X_pca = pca.fit_transform(X_scaled)
263
+
264
+ # Get feature importance via absolute component values
265
+ importance = np.abs(pca.components_).sum(axis=0)
266
+ importance_series = pd.Series(importance, index=X.columns)
267
+ importance_series = importance_series.sort_values(ascending=False)
268
+
269
+ selected_features = importance_series.head(n_features).index.tolist()
270
+ feature_importance = importance_series.to_dict()
271
+
272
+ self.selector_objects['pca'] = pca
273
+ self.selector_objects['scaler'] = scaler
274
+
275
+ return selected_features, feature_importance
276
+
277
+ except Exception as e:
278
+ logger.warning(f"PCA selection failed: {e}, using correlation")
279
+ return self._correlation_selection(X, y, n_features, **kwargs)
280
+
281
+ def _rfe_selection(
282
+ self,
283
+ X: pd.DataFrame,
284
+ y: pd.Series,
285
+ n_features: int,
286
+ **kwargs
287
+ ) -> Tuple[List[str], Dict]:
288
+ """Recursive Feature Elimination"""
289
+ try:
290
+ from sklearn.feature_selection import RFE
291
+ from sklearn.linear_model import LinearRegression
292
+
293
+ estimator = LinearRegression()
294
+ rfe = RFE(
295
+ estimator=estimator,
296
+ n_features_to_select=n_features,
297
+ step=kwargs.get('step', 1)
298
+ )
299
+
300
+ rfe.fit(X, y)
301
+ selected_mask = rfe.support_
302
+ selected_features = X.columns[selected_mask].tolist()
303
+
304
+ # Feature importance via ranking
305
+ ranking = pd.Series(rfe.ranking_, index=X.columns)
306
+ feature_importance = (1 / ranking).to_dict() # Convert ranking to importance
307
+
308
+ self.selector_objects['rfe'] = rfe
309
+
310
+ return selected_features, feature_importance
311
+
312
+ except Exception as e:
313
+ logger.warning(f"RFE selection failed: {e}, using correlation")
314
+ return self._correlation_selection(X, y, n_features, **kwargs)
315
+
316
+ def _lasso_selection(
317
+ self,
318
+ X: pd.DataFrame,
319
+ y: pd.Series,
320
+ n_features: int,
321
+ **kwargs
322
+ ) -> Tuple[List[str], Dict]:
323
+ """Feature selection using Lasso"""
324
+ try:
325
+ from sklearn.linear_model import LassoCV
326
+
327
+ lasso = LassoCV(
328
+ cv=kwargs.get('cv', 5),
329
+ random_state=kwargs.get('random_state', 42),
330
+ max_iter=kwargs.get('max_iter', 1000)
331
+ )
332
+
333
+ lasso.fit(X, y)
334
+
335
+ # Features with non-zero coefficients
336
+ coefficients = pd.Series(lasso.coef_, index=X.columns)
337
+ non_zero_features = coefficients[coefficients != 0].abs().sort_values(ascending=False)
338
+
339
+ # Select top-n_features
340
+ selected_features = non_zero_features.head(n_features).index.tolist()
341
+ feature_importance = non_zero_features.to_dict()
342
+
343
+ self.selector_objects['lasso'] = lasso
344
+
345
+ return selected_features, feature_importance
346
+
347
+ except Exception as e:
348
+ logger.warning(f"Lasso selection failed: {e}, using correlation")
349
+ return self._correlation_selection(X, y, n_features, **kwargs)
350
+
351
+ def _hybrid_selection(
352
+ self,
353
+ X: pd.DataFrame,
354
+ y: pd.Series,
355
+ n_features: int,
356
+ **kwargs
357
+ ) -> Tuple[List[str], Dict]:
358
+ """Hybrid feature selection method"""
359
+ # Combine multiple methods
360
+ methods = kwargs.get('methods', ['correlation', 'mutual_info', 'rf'])
361
+ weights = kwargs.get('weights', [0.3, 0.3, 0.4])
362
+
363
+ all_importances = {}
364
+
365
+ for method, weight in zip(methods, weights):
366
+ try:
367
+ if method == 'correlation':
368
+ _, importance = self._correlation_selection(X, y, n_features, **kwargs)
369
+ elif method == 'mutual_info':
370
+ _, importance = self._mutual_info_selection(X, y, n_features, **kwargs)
371
+ elif method == 'rf':
372
+ _, importance = self._random_forest_selection(X, y, n_features, **kwargs)
373
+ else:
374
+ continue
375
+
376
+ # Normalise importances and weight them
377
+ importance_series = pd.Series(importance)
378
+ if importance_series.max() > importance_series.min():
379
+ importance_normalized = (importance_series - importance_series.min()) / \
380
+ (importance_series.max() - importance_series.min())
381
+ else:
382
+ importance_normalized = pd.Series(1, index=importance_series.index)
383
+
384
+ # Add weighted importances
385
+ for feature in importance_normalized.index:
386
+ if feature not in all_importances:
387
+ all_importances[feature] = 0
388
+ all_importances[feature] += importance_normalized[feature] * weight
389
+
390
+ except Exception as e:
391
+ logger.debug(f"Method {method} failed in hybrid selection: {e}")
392
+
393
+ # Sort by total importance
394
+ combined_importance = pd.Series(all_importances).sort_values(ascending=False)
395
+ selected_features = combined_importance.head(n_features).index.tolist()
396
+
397
+ return selected_features, combined_importance.to_dict()
398
+
399
+ def _plot_feature_selection(
400
+ self,
401
+ X: pd.DataFrame,
402
+ y: pd.Series,
403
+ selected_features: List[str],
404
+ feature_importance: Dict,
405
+ method: str
406
+ ) -> None:
407
+ """Visualise feature selection results"""
408
+ # Prepare data for visualisation
409
+ importance_series = pd.Series(feature_importance).sort_values(ascending=False)
410
+
411
+ # Limit number of features for display
412
+ display_features = importance_series.head(20)
413
+
414
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
415
+
416
+ # 1. Feature importance
417
+ y_pos = np.arange(len(display_features))
418
+ axes[0, 0].barh(y_pos, display_features.values)
419
+ axes[0, 0].set_yticks(y_pos)
420
+ axes[0, 0].set_yticklabels(display_features.index, fontsize=9)
421
+ axes[0, 0].invert_yaxis()
422
+ axes[0, 0].set_xlabel('Importance')
423
+ axes[0, 0].set_title(f'Top-{len(display_features)} features by importance ({method})')
424
+ axes[0, 0].grid(True, alpha=0.3, axis='x')
425
+
426
+ # 2. Cumulative importance
427
+ cumulative_importance = importance_series.cumsum() / importance_series.sum()
428
+ axes[0, 1].plot(range(1, len(cumulative_importance) + 1), cumulative_importance.values)
429
+ axes[0, 1].axhline(y=0.8, color='r', linestyle='--', alpha=0.7, label='80% importance')
430
+ axes[0, 1].axhline(y=0.9, color='orange', linestyle='--', alpha=0.7, label='90% importance')
431
+ axes[0, 1].set_xlabel('Number of features')
432
+ axes[0, 1].set_ylabel('Cumulative importance')
433
+ axes[0, 1].set_title('Cumulative feature importance')
434
+ axes[0, 1].legend()
435
+ axes[0, 1].grid(True, alpha=0.3)
436
+
437
+ # 3. Correlation matrix of selected features
438
+ if len(selected_features) > 1:
439
+ selected_X = X[selected_features]
440
+ corr_matrix = selected_X.corr()
441
+
442
+ mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
443
+ sns.heatmap(
444
+ corr_matrix,
445
+ annot=True,
446
+ fmt='.2f',
447
+ cmap='coolwarm',
448
+ center=0,
449
+ square=True,
450
+ mask=mask,
451
+ cbar_kws={'shrink': 0.8},
452
+ ax=axes[1, 0]
453
+ )
454
+ axes[1, 0].set_title(f'Correlation of selected features ({len(selected_features)})')
455
+
456
+ # 4. Importance distribution
457
+ axes[1, 1].hist(importance_series.values, bins=30, edgecolor='black', alpha=0.7)
458
+ axes[1, 1].set_xlabel('Feature importance')
459
+ axes[1, 1].set_ylabel('Frequency')
460
+ axes[1, 1].set_title('Feature importance distribution')
461
+ axes[1, 1].grid(True, alpha=0.3)
462
+
463
+ plt.suptitle(f'Feature selection results using {method} method', fontsize=14)
464
+ plt.tight_layout()
465
+ plt.savefig(
466
+ f'{self.config.results_dir}/plots/feature_selection_{method}.png',
467
+ dpi=300,
468
+ bbox_inches='tight'
469
+ )
470
+ plt.close(fig)
471
+
472
+ def get_report(self) -> Dict:
473
+ """Get feature selection report"""
474
+ return {
475
+ 'selected_features': self.selected_features,
476
+ 'feature_importances': self.feature_importances,
477
+ 'selection_methods': self.selection_methods
478
+ }
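
And a minimal usage sketch for the selector, again assuming a constructible `Config` and a numeric DataFrame `df` containing the target column; the hybrid weights shown simply mirror the defaults in `_hybrid_selection`:

```python
from config.config import Config
from feature_selection.feature_selector import FeatureSelector

config = Config()          # assumption: workable defaults
config.save_plots = False  # skip the diagnostic plots in this sketch
selector = FeatureSelector(config)

df_selected = selector.select(
    df,
    target_col="raskhodvoda",
    method="hybrid",
    n_features=15,
    methods=["correlation", "mutual_info", "rf"],
    weights=[0.3, 0.3, 0.4],
)
print(selector.get_report()["selected_features"])
```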
features/__init__.py ADDED
File without changes
features/feature_engineer.py ADDED
@@ -0,0 +1,638 @@
1
+ # ============================================
2
+ # CLASS 5: FEATURE ENGINEER
3
+ # ============================================
4
+ from typing import Dict, List, Optional
5
+ import logging
+ logger = logging.getLogger(__name__)
6
+
7
+ from config.config import Config
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+
12
+
13
+ class FeatureEngineer:
14
+ """Class for creating new features for time series"""
15
+
16
+ def __init__(self, config: Config):
17
+ """
18
+ Initialise feature engineer
19
+
20
+ Parameters:
21
+ -----------
22
+ config : Config
23
+ Experiment configuration
24
+ """
25
+ self.config = config
26
+ self.created_features = []
27
+ self.feature_info = {}
28
+ self.feature_importances = {}
29
+ self.transforms_applied = {}
30
+
31
+ def create_all_features(
32
+ self,
33
+ data: pd.DataFrame,
34
+ target_col: Optional[str] = None
35
+ ) -> pd.DataFrame:
36
+ """
37
+ Create all types of features
38
+
39
+ Parameters:
40
+ -----------
41
+ data : pd.DataFrame
42
+ Input data
43
+ target_col : str, optional
44
+ Target variable. If None, uses configuration value.
45
+
46
+ Returns:
47
+ --------
48
+ pd.DataFrame
49
+ Data with all features
50
+ """
51
+ logger.info("\n" + "="*80)
52
+ logger.info("CREATING FEATURES FOR TIME SERIES")
53
+ logger.info("="*80)
54
+
55
+ target_col = target_col or self.config.target_column
56
+ initial_features = len(data.columns)
57
+ initial_rows = len(data)
58
+
59
+ # Check and save index
60
+ original_index = data.index
61
+ index_is_datetime = isinstance(original_index, pd.DatetimeIndex)
62
+
63
+ logger.info(f"Initial number of features: {initial_features}")
64
+ logger.info(f"Initial number of rows: {initial_rows}")
65
+ logger.info(f"Index is DatetimeIndex: {index_is_datetime}")
66
+
67
+ # If index not DatetimeIndex but 'date' column exists
68
+ if not index_is_datetime and 'date' in data.columns:
69
+ logger.info("Attempting to set DatetimeIndex from 'date' column")
70
+ try:
71
+ data = data.set_index('date')
72
+ if isinstance(data.index, pd.DatetimeIndex):
73
+ index_is_datetime = True
74
+ original_index = data.index
75
+ logger.info("✓ DatetimeIndex set from 'date' column")
76
+ else:
77
+ logger.warning("Failed to set DatetimeIndex")
78
+ except Exception as e:
79
+ logger.warning(f"Error setting DatetimeIndex: {e}")
80
+
81
+ # Save data copy for index restoration later
82
+ data_processed = data.copy()
83
+
84
+ # 1. Create basic temporal features (if date exists)
85
+ if index_is_datetime:
86
+ logger.info("\n1. BASIC TEMPORAL FEATURES")
87
+ data_processed = self.create_temporal_features(data_processed)
88
+ else:
89
+ logger.info("\n1. BASIC TEMPORAL FEATURES: skipped (no DatetimeIndex)")
90
+
91
+ # 2. Create statistical features
92
+ logger.info("\n2. STATISTICAL FEATURES")
93
+ data_processed = self.create_statistical_features(data_processed, target_col)
94
+
95
+ # 3. Create rolling features
96
+ logger.info("\n3. ROLLING FEATURES")
97
+ data_processed = self.create_rolling_features(data_processed, target_col)
98
+
99
+ # 4. Create lag features (limited quantity)
100
+ logger.info("\n4. LAG FEATURES")
101
+ data_processed = self.create_lag_features(data_processed, target_col)
102
+
103
+ # 5. Create interaction features
104
+ logger.info("\n5. INTERACTION FEATURES")
105
+ data_processed = self.create_interaction_features(data_processed, target_col)
106
+
107
+ # 6. Create spectral features (only if sufficient data)
108
+ logger.info("\n6. SPECTRAL FEATURES")
109
+ if len(data_processed) > 100:
110
+ data_processed = self.create_spectral_features(data_processed, target_col)
111
+ else:
112
+ logger.info(" Skipped: insufficient data")
113
+
114
+ # 7. Create decomposition features (only if sufficient data and date exists)
115
+ logger.info("\n7. DECOMPOSITION FEATURES")
116
+ if len(data_processed) > 365 and index_is_datetime:
117
+ data_processed = self.create_decomposition_features(data_processed, target_col)
118
+ else:
119
+ logger.info(" Skipped: insufficient data or no DatetimeIndex")
120
+
121
+ # Remove rows with NaN that appeared due to lags and differences
122
+ rows_before_nan = len(data_processed)
123
+ data_processed = data_processed.dropna()
124
+ rows_after_nan = len(data_processed)
125
+ removed_rows = rows_before_nan - rows_after_nan
126
+
127
+ # Remove constant features
128
+ constant_features = []
129
+ for col in data_processed.columns:
130
+ if data_processed[col].nunique() <= 1:
131
+ constant_features.append(col)
132
+
133
+ if constant_features:
134
+ logger.info(f"\nRemoving constant features: {len(constant_features)} found")
135
+ for feat in constant_features[:10]:
136
+ logger.info(f" - {feat}")
137
+ if len(constant_features) > 10:
138
+ logger.info(f" ... and {len(constant_features) - 10} more features")
139
+
140
+ data_processed = data_processed.drop(columns=constant_features)
141
+ # Update created features list
142
+ self.created_features = [f for f in self.created_features if f not in constant_features]
143
+
144
+ # Save information
145
+ self.feature_info = {
146
+ 'initial_features': initial_features,
147
+ 'final_features': len(data_processed.columns),
148
+ 'features_created': len(self.created_features),
149
+ 'initial_rows': initial_rows,
150
+ 'final_rows': len(data_processed),
151
+ 'removed_rows': removed_rows,
152
+ 'constant_features_removed': len(constant_features),
153
+ 'created_features_list': self.created_features,
154
+ 'feature_categories': self.get_feature_categories()
155
+ }
156
+
157
+ logger.info(f"\nFeature creation summary:")
158
+ logger.info(f" Initial number of features: {initial_features}")
159
+ logger.info(f" Final number of features: {len(data_processed.columns)}")
160
+ logger.info(f" New features created: {len(self.created_features)}")
161
+ logger.info(f" Initial number of rows: {initial_rows}")
162
+ logger.info(f" Final number of rows: {len(data_processed)}")
163
+ logger.info(f" Rows removed due to NaN: {removed_rows}")
164
+ logger.info(f" Constant features removed: {len(constant_features)}")
165
+
166
+ return data_processed
167
+
168
+ def create_temporal_features(self, data: pd.DataFrame) -> pd.DataFrame:
169
+ """
170
+ Create temporal features
171
+
172
+ Parameters:
173
+ -----------
174
+ data : pd.DataFrame
175
+ Input data
176
+
177
+ Returns:
178
+ --------
179
+ pd.DataFrame
180
+ Data with temporal features
181
+ """
182
+ data_processed = data.copy()
183
+
184
+ if not isinstance(data_processed.index, pd.DatetimeIndex):
185
+ logger.warning("Temporal features not created: index not DatetimeIndex")
186
+ return data_processed
187
+
188
+ try:
189
+ # Basic temporal features
190
+ data_processed['year'] = data_processed.index.year
191
+ data_processed['month'] = data_processed.index.month
192
+ data_processed['day'] = data_processed.index.day
193
+ data_processed['dayofyear'] = data_processed.index.dayofyear
194
+ data_processed['dayofweek'] = data_processed.index.dayofweek
195
+ data_processed['weekofyear'] = data_processed.index.isocalendar().week.astype(int)
196
+ data_processed['quarter'] = data_processed.index.quarter
197
+ data_processed['is_weekend'] = data_processed['dayofweek'].isin([5, 6]).astype(int)
198
+
199
+ # Cyclic features for seasonality
200
+ data_processed['month_sin'] = np.sin(2 * np.pi * data_processed['month'] / 12)
201
+ data_processed['month_cos'] = np.cos(2 * np.pi * data_processed['month'] / 12)
202
+ data_processed['dayofyear_sin'] = np.sin(2 * np.pi * data_processed['dayofyear'] / 365.25)
203
+ data_processed['dayofyear_cos'] = np.cos(2 * np.pi * data_processed['dayofyear'] / 365.25)
204
+ data_processed['dayofweek_sin'] = np.sin(2 * np.pi * data_processed['dayofweek'] / 7)
205
+ data_processed['dayofweek_cos'] = np.cos(2 * np.pi * data_processed['dayofweek'] / 7)
206
+
207
+ # Time in days from start (relative features)
208
+ min_date = data_processed.index.min()
209
+ data_processed['days_from_start'] = (data_processed.index - min_date).days
210
+
211
+ # Register created features
212
+ temporal_features = ['year', 'month', 'day', 'dayofyear', 'dayofweek',
213
+ 'weekofyear', 'quarter', 'is_weekend', 'month_sin',
214
+ 'month_cos', 'dayofyear_sin', 'dayofyear_cos',
215
+ 'dayofweek_sin', 'dayofweek_cos', 'days_from_start']
216
+
217
+ self.created_features.extend([f for f in temporal_features if f not in self.created_features])
218
+
219
+ logger.info(f"✓ Created {len(temporal_features)} temporal features")
220
+
221
+ except Exception as e:
222
+ logger.warning(f"Error creating temporal features: {e}")
223
+
224
+ return data_processed
225
+
226
+ def create_statistical_features(
227
+ self,
228
+ data: pd.DataFrame,
229
+ target_col: str
230
+ ) -> pd.DataFrame:
231
+ """
232
+ Create statistical features
233
+
234
+ Parameters:
235
+ -----------
236
+ data : pd.DataFrame
237
+ Input data
238
+ target_col : str
239
+ Target variable
240
+
241
+ Returns:
242
+ --------
243
+ pd.DataFrame
244
+ Data with statistical features
245
+ """
246
+ data_processed = data.copy()
247
+
248
+ if target_col not in data_processed.columns:
249
+ logger.warning(f"Target variable '{target_col}' not found")
250
+ return data_processed
251
+
252
+ # Only if we have year data
253
+ if 'year' in data_processed.columns:
254
+ # Yearly statistics
255
+ try:
256
+ yearly_stats = data_processed.groupby('year')[target_col].agg([
257
+ 'mean', 'std', 'min', 'max', 'median'
258
+ ])
259
+ yearly_stats.columns = [f'{target_col}_yearly_{col}' for col in yearly_stats.columns]
260
+ data_processed = data_processed.merge(yearly_stats, on='year', how='left')
261
+
262
+ # Add created features to list
263
+ for col in yearly_stats.columns:
264
+ self.created_features.append(col)
265
+ except Exception as e:
266
+ logger.debug(f"Yearly statistics not created: {e}")
267
+
268
+ # Normalised features (only if there is variation)
269
+ std_val = data_processed[target_col].std()
270
+ if std_val > 0:
271
+ data_processed[f'{target_col}_zscore'] = (data_processed[target_col] - data_processed[target_col].mean()) / std_val
272
+ self.created_features.append(f'{target_col}_zscore')
273
+
274
+ # Features based on percentiles (binary features)
275
+ try:
276
+ for p in [0.25, 0.5, 0.75]:
277
+ quantile_val = data_processed[target_col].quantile(p)
278
+ data_processed[f'{target_col}_above_p{int(p*100)}'] = (data_processed[target_col] > quantile_val).astype(int)
279
+ self.created_features.append(f'{target_col}_above_p{int(p*100)}')
280
+ except Exception as e:
281
+ logger.debug(f"Quantile features not created: {e}")
282
+
283
+ logger.info(f"✓ Statistical features created: {len([c for c in data_processed.columns if c not in data.columns])}")
284
+ return data_processed
285
+
286
+ def create_rolling_features(
287
+ self,
288
+ data: pd.DataFrame,
289
+ target_col: str
290
+ ) -> pd.DataFrame:
291
+ """
292
+ Create rolling statistics
293
+
294
+ Parameters:
295
+ -----------
296
+ data : pd.DataFrame
297
+ Input data
298
+ target_col : str
299
+ Target variable
300
+
301
+ Returns:
302
+ --------
303
+ pd.DataFrame
304
+ Data with rolling features
305
+ """
306
+ data_processed = data.copy()
307
+
308
+ if target_col not in data_processed.columns:
309
+ logger.warning(f"Target variable '{target_col}' not found")
310
+ return data_processed
311
+
312
+ # Use only main windows from configuration
313
+ windows = [w for w in self.config.rolling_windows if w < len(data_processed) // 2]
314
+
315
+ for window in windows:
316
+ try:
317
+ # Basic statistics
318
+ data_processed[f'{target_col}_rolling_mean_{window}'] = data_processed[target_col].rolling(
319
+ window=window, min_periods=max(1, window//4), center=True
320
+ ).mean()
321
+
322
+ data_processed[f'{target_col}_rolling_std_{window}'] = data_processed[target_col].rolling(
323
+ window=window, min_periods=max(1, window//4), center=True
324
+ ).std()
325
+
326
+ data_processed[f'{target_col}_rolling_min_{window}'] = data_processed[target_col].rolling(
327
+ window=window, min_periods=max(1, window//4), center=True
328
+ ).min()
329
+
330
+ data_processed[f'{target_col}_rolling_max_{window}'] = data_processed[target_col].rolling(
331
+ window=window, min_periods=max(1, window//4), center=True
332
+ ).max()
333
+
334
+ self.created_features.extend([
335
+ f'{target_col}_rolling_mean_{window}',
336
+ f'{target_col}_rolling_std_{window}',
337
+ f'{target_col}_rolling_min_{window}',
338
+ f'{target_col}_rolling_max_{window}'
339
+ ])
340
+ except Exception as e:
341
+ logger.debug(f"Rolling features for window {window} not created: {e}")
342
+ continue
343
+
344
+ logger.info(f"✓ Rolling features created: {len([c for c in data_processed.columns if 'rolling' in c and c not in data.columns])}")
345
+ return data_processed
346
+
347
+ def create_lag_features(
348
+ self,
349
+ data: pd.DataFrame,
350
+ target_col: str
351
+ ) -> pd.DataFrame:
352
+ """
353
+ Create lag features
354
+
355
+ Parameters:
356
+ -----------
357
+ data : pd.DataFrame
358
+ Input data
359
+ target_col : str
360
+ Target variable
361
+
362
+ Returns:
363
+ --------
364
+ pd.DataFrame
365
+ Data with lag features
366
+ """
367
+ data_processed = data.copy()
368
+
369
+ if target_col not in data_processed.columns:
370
+ logger.warning(f"Target variable '{target_col}' not found")
371
+ return data_processed
372
+
373
+ # Limited number of lags
374
+ max_lags = min(self.config.max_lags, 7) # Maximum 7 lags
375
+
376
+ for lag in [1, 2, 3, 7, 14, 30]:
377
+ if lag <= max_lags:
378
+ data_processed[f'{target_col}_lag_{lag}'] = data_processed[target_col].shift(lag)
379
+ self.created_features.append(f'{target_col}_lag_{lag}')
380
+
381
+ # Seasonal lags (only if sufficient data)
382
+ if len(data_processed) > 365:
383
+ try:
384
+ data_processed[f'{target_col}_seasonal_lag_365'] = data_processed[target_col].shift(365)
385
+ self.created_features.append(f'{target_col}_seasonal_lag_365')
386
+ except Exception as e:
387
+ logger.debug(f"Seasonal lag not created: {e}")
388
+
389
+ # Differences (stationarity)
390
+ data_processed[f'{target_col}_diff_1'] = data_processed[target_col].diff(1)
391
+ self.created_features.append(f'{target_col}_diff_1')
392
+
393
+ if len(data_processed) > 7:
394
+ data_processed[f'{target_col}_diff_7'] = data_processed[target_col].diff(7)
395
+ self.created_features.append(f'{target_col}_diff_7')
396
+
397
+ logger.info(f"✓ Lag features created: {len([c for c in data_processed.columns if ('lag' in c or 'diff' in c) and c not in data.columns])}")
398
+ return data_processed
399
+
400
+ def create_interaction_features(
401
+ self,
402
+ data: pd.DataFrame,
403
+ target_col: str
404
+ ) -> pd.DataFrame:
405
+ """
406
+ Create interaction features
407
+
408
+ Parameters:
409
+ -----------
410
+ data : pd.DataFrame
411
+ Input data
412
+ target_col : str
413
+ Target variable
414
+
415
+ Returns:
416
+ --------
417
+ pd.DataFrame
418
+ Data with interaction features
419
+ """
420
+ data_processed = data.copy()
421
+
422
+ if target_col not in data_processed.columns:
423
+ logger.warning(f"Target variable '{target_col}' not found")
424
+ return data_processed
425
+
426
+ # Interactions with temperature (only if data exists)
427
+ temp_cols = ['tavg', 'tmin', 'tmax']
428
+ available_temp_cols = [col for col in temp_cols if col in data_processed.columns]
429
+
430
+ for temp_col in available_temp_cols:
431
+ try:
432
+ # Avoid division by zero
433
+ temp_data = data_processed[temp_col].replace(0, np.nan)
434
+ if temp_data.notna().all() and (temp_data != 0).all():
435
+ data_processed[f'{target_col}_{temp_col}_ratio'] = data_processed[target_col] / temp_data
436
+ self.created_features.append(f'{target_col}_{temp_col}_ratio')
437
+
438
+ # Product
439
+ data_processed[f'{target_col}_{temp_col}_product'] = data_processed[target_col] * temp_data
440
+ self.created_features.append(f'{target_col}_{temp_col}_product')
441
+ except Exception as e:
442
+ logger.debug(f"Interaction feature with {temp_col} not created: {e}")
443
+
444
+ # Interaction with water level
445
+ if 'urovenvoda' in data_processed.columns:
446
+ try:
447
+ uroven_data = data_processed['urovenvoda'].replace(0, np.nan)
448
+ if uroven_data.notna().all() and (uroven_data != 0).all():
449
+ data_processed[f'{target_col}_urovenvoda_ratio'] = data_processed[target_col] / uroven_data
450
+ self.created_features.append(f'{target_col}_urovenvoda_ratio')
451
+ except Exception as e:
452
+ logger.debug(f"Interaction feature with urovenvoda not created: {e}")
453
+
454
+ logger.info(f"✓ Interaction features created: {len([c for c in data_processed.columns if ('ratio' in c or 'product' in c) and c not in data.columns])}")
455
+ return data_processed
456
+
457
+ def create_spectral_features(
458
+ self,
459
+ data: pd.DataFrame,
460
+ target_col: str
461
+ ) -> pd.DataFrame:
462
+ """
463
+ Create spectral features
464
+
465
+ Parameters:
466
+ -----------
467
+ data : pd.DataFrame
468
+ Input data
469
+ target_col : str
470
+ Target variable
471
+
472
+ Returns:
473
+ --------
474
+ pd.DataFrame
475
+ Data with spectral features
476
+ """
477
+ data_processed = data.copy()
478
+
479
+ if target_col not in data_processed.columns:
480
+ logger.warning(f"Target variable '{target_col}' not found")
481
+ return data_processed
482
+
483
+ if len(data_processed) < 100:
484
+ logger.info("Insufficient data for creating spectral features")
485
+ return data_processed
486
+
487
+ try:
488
+ # Fast Fourier Transform
489
+ series = data_processed[target_col].dropna().values
490
+
491
+ if len(series) > 50:
492
+ # Calculate periodogram
493
+ from scipy.signal import periodogram
494
+ freqs, psd = periodogram(series, fs=1.0)
495
+
496
+ # Find dominant frequencies
497
+ if len(psd) > 3:
498
+ # Top-3 frequencies by power
499
+ top_indices = np.argsort(psd)[-3:][::-1]
500
+
501
+ for i, idx in enumerate(top_indices, 1):
502
+ if idx < len(freqs):
503
+ freq = freqs[idx]
504
+ if freq > 0:
505
+ period = 1 / freq
506
+ data_processed[f'{target_col}_dominant_period_{i}'] = period
507
+ self.created_features.append(f'{target_col}_dominant_period_{i}')
508
+
509
+ except Exception as e:
510
+ logger.debug(f"Spectral features creation failed: {e}")
511
+
512
+ return data_processed
513
+
514
+ def create_decomposition_features(
515
+ self,
516
+ data: pd.DataFrame,
517
+ target_col: str
518
+ ) -> pd.DataFrame:
519
+ """
520
+ Create features based on decomposition
521
+
522
+ Parameters:
523
+ -----------
524
+ data : pd.DataFrame
525
+ Input data
526
+ target_col : str
527
+ Target variable
528
+
529
+ Returns:
530
+ --------
531
+ pd.DataFrame
532
+ Data with decomposition features
533
+ """
534
+ data_processed = data.copy()
535
+
536
+ if target_col not in data_processed.columns:
537
+ logger.warning(f"Target variable '{target_col}' not found")
538
+ return data_processed
539
+
540
+ if len(data_processed) < 365:
541
+ logger.info("Insufficient data for decomposition")
542
+ return data_processed
543
+
544
+ try:
545
+ # Check for date presence
546
+ if isinstance(data_processed.index, pd.DatetimeIndex):
547
+ # STL decomposition
548
+ if len(data_processed) > 730: # Need at least 2 years for yearly seasonality
549
+ try:
550
+ from statsmodels.tsa.seasonal import STL
551
+
552
+ # STL decomposition
553
+ stl = STL(
554
+ data_processed[target_col].fillna(method='ffill'),
555
+ period=365,
556
+ robust=True
557
+ )
558
+ result = stl.fit()
559
+
560
+ # Add components
561
+ data_processed[f'{target_col}_trend'] = result.trend
562
+ data_processed[f'{target_col}_seasonal'] = result.seasonal
563
+ data_processed[f'{target_col}_residual'] = result.resid
564
+
565
+ self.created_features.extend([
566
+ f'{target_col}_trend',
567
+ f'{target_col}_seasonal',
568
+ f'{target_col}_residual'
569
+ ])
570
+
571
+ logger.info("✓ STL decomposition successful")
572
+
573
+ except Exception as e:
574
+ logger.debug(f"STL decomposition failed: {e}")
575
+ # Simple seasonal decomposition
576
+ try:
577
+ from statsmodels.tsa.seasonal import seasonal_decompose
578
+
579
+ decomposition = seasonal_decompose(
580
+ data_processed[target_col].fillna(method='ffill'),
581
+ model='additive',
582
+ period=365,
583
+ extrapolate_trend='freq'
584
+ )
585
+
586
+ data_processed[f'{target_col}_trend'] = decomposition.trend
587
+ data_processed[f'{target_col}_seasonal'] = decomposition.seasonal
588
+
589
+ self.created_features.extend([
590
+ f'{target_col}_trend',
591
+ f'{target_col}_seasonal'
592
+ ])
593
+
594
+ logger.info("✓ Seasonal decomposition successful")
595
+ except Exception as e2:
596
+ logger.debug(f"Seasonal decomposition failed: {e2}")
597
+
598
+ except Exception as e:
599
+ logger.debug(f"Decomposition features creation failed: {e}")
600
+
601
+ return data_processed
602
+
603
+ def get_feature_categories(self) -> Dict[str, List[str]]:
604
+ """Get features by categories"""
605
+ categories = {
606
+ 'temporal': [],
607
+ 'statistical': [],
608
+ 'rolling': [],
609
+ 'lag': [],
610
+ 'interaction': [],
611
+ 'spectral': [],
612
+ 'decomposition': [],
613
+ 'binary': []
614
+ }
615
+
616
+ for feature in self.created_features:
617
+ if any(keyword in feature for keyword in ['year', 'month', 'day', 'week', 'quarter', 'sin', 'cos', 'is_weekend']):
618
+ categories['temporal'].append(feature)
619
+ elif any(keyword in feature for keyword in ['zscore', 'above_p', 'yearly_']):
620
+ if 'above_p' in feature:
621
+ categories['binary'].append(feature)
622
+ else:
623
+ categories['statistical'].append(feature)
624
+ elif 'rolling' in feature:
625
+ categories['rolling'].append(feature)
626
+ elif any(keyword in feature for keyword in ['lag', 'diff']):
627
+ categories['lag'].append(feature)
628
+ elif 'ratio' in feature or 'product' in feature:
629
+ categories['interaction'].append(feature)
630
+ elif 'dominant' in feature:
631
+ categories['spectral'].append(feature)
632
+ elif any(keyword in feature for keyword in ['trend', 'seasonal', 'residual']):
633
+ categories['decomposition'].append(feature)
634
+
635
+ # Remove empty categories
636
+ categories = {k: v for k, v in categories.items() if v}
637
+
638
+ return categories
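
A hedged sketch of driving `FeatureEngineer` end to end on synthetic daily data. The stub config is an assumption carrying only the attributes this class actually reads (`target_column`, `rolling_windows`, `max_lags`); the real `config.config.Config` may define more:

```python
from dataclasses import dataclass, field
from typing import List

import numpy as np
import pandas as pd

from features.feature_engineer import FeatureEngineer


@dataclass
class StubConfig:
    # Assumed attribute names, taken from what FeatureEngineer reads above
    target_column: str = "discharge"
    rolling_windows: List[int] = field(default_factory=lambda: [7, 30, 90])
    max_lags: int = 7


# Synthetic daily series with a yearly cycle plus noise
dates = pd.date_range("2018-01-01", periods=1200, freq="D")
rng = np.random.default_rng(0)
df = pd.DataFrame(
    {"discharge": 50 + 10 * np.sin(2 * np.pi * dates.dayofyear / 365.25)
                  + rng.normal(0, 2, len(dates))},
    index=dates,
)

engineer = FeatureEngineer(StubConfig())  # duck-typed in place of Config
enriched = engineer.create_all_features(df)
print(enriched.shape, engineer.get_feature_categories())
```

Because the index is a DatetimeIndex and the series spans more than two years, every stage up to and including STL decomposition should fire; the constant dominant-period columns are then expected to be dropped again by the constant-feature filter.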
missing_values/__init__.py ADDED
File without changes
missing_values/missing_analyzer.py ADDED
@@ -0,0 +1,700 @@
+ # ============================================
+ # CLASS 3: MISSING VALUE ANALYSER
+ # ============================================
+ import logging
+ from typing import Dict, Tuple
+
+ from config.config import Config
+ from scipy.interpolate import interp1d
+ from statsmodels.tsa.seasonal import STL
+
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ logger = logging.getLogger(__name__)
+
+
+ class MissingValueAnalyser:
+     """Class for analysing and handling missing values"""
+
+     def __init__(self, config: Config):
+         """
+         Initialise missing value analyser
+
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.missing_info = {}
+         self.handling_methods = {}
+         self.imputers = {}
+         self.missing_patterns = {}
+
+     def analyse(
+         self,
+         data: pd.DataFrame,
+         detailed: bool = True
+     ) -> Dict:
+         """
+         Analyse missing values in data
+
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         detailed : bool
+             Whether to perform detailed analysis
+
+         Returns:
+         --------
+         Dict
+             Information about missing values
+         """
+         logger.info("\n" + "="*80)
+         logger.info("MISSING VALUE ANALYSIS")
+         logger.info("="*80)
+
+         # Calculate missing values
+         missing_total = data.isnull().sum()
+         missing_percent = (missing_total / len(data)) * 100
+
+         missing_df = pd.DataFrame({
+             'missing_count': missing_total,
+             'missing_percent': missing_percent,
+             'dtype': data.dtypes.astype(str)
+         })
+
+         # Detailed analysis
+         if detailed:
+             self._detailed_missing_analysis(data, missing_df)
+
+         # Save information
+         self.missing_info = {
+             'summary': {
+                 col: {
+                     'missing_count': int(missing_df.loc[col, 'missing_count']),
+                     'missing_percent': float(missing_df.loc[col, 'missing_percent']),
+                     'dtype': missing_df.loc[col, 'dtype']
+                 }
+                 for col in missing_df.index
+             },
+             'overall': {
+                 'total_missing': int(missing_total.sum()),
+                 'total_rows': int(len(data)),
+                 'total_cells': int(data.size),
+                 'overall_missing_percentage': float(missing_total.sum() / data.size * 100),
+                 'rows_with_any_missing': int(data.isnull().any(axis=1).sum()),
+                 'rows_all_missing': int(data.isnull().all(axis=1).sum()),
+                 'columns_with_missing': missing_df[missing_df['missing_count'] > 0].index.tolist(),
+                 'columns_all_missing': missing_df[missing_df['missing_count'] == len(data)].index.tolist()
+             }
+         }
+
+         # Visualisation
+         if self.config.save_plots:
+             self._plot_missing_values(data, missing_df)
+
+         # Output results
+         self._log_missing_summary(missing_df)
+
+         return self.missing_info
+
+     def _detailed_missing_analysis(
+         self,
+         data: pd.DataFrame,
+         missing_df: pd.DataFrame
+     ) -> None:
+         """Detailed missing value analysis"""
+         # Analyse missing patterns
+         missing_matrix = data.isnull()
+
+         # Row missing patterns
+         row_patterns = missing_matrix.apply(lambda x: ''.join(x.astype(int).astype(str)), axis=1)
+         row_pattern_counts = row_patterns.value_counts().head(10)
+
+         # Column missing patterns
+         col_patterns = missing_matrix.apply(lambda x: ''.join(x.astype(int).astype(str)), axis=0)
+         col_pattern_counts = col_patterns.value_counts().head(10)
+
+         # Time-based missing pattern analysis
+         time_patterns = {}
+         if isinstance(data.index, pd.DatetimeIndex):
+             # Missing values by time
+             time_missing = data.isnull().resample('M').sum()
+             time_patterns['monthly_missing'] = time_missing.sum(axis=1).to_dict()
+
+             # Missing values by day of week
+             data_with_dow = data.copy()
+             data_with_dow['dayofweek'] = data.index.dayofweek
+             dow_missing = data_with_dow.groupby('dayofweek').apply(lambda x: x.isnull().sum().sum())
+             time_patterns['dayofweek_missing'] = dow_missing.to_dict()
+
+         self.missing_patterns = {
+             'row_patterns': row_pattern_counts.to_dict(),
+             'col_patterns': col_pattern_counts.to_dict(),
+             'time_patterns': time_patterns,
+             'missing_correlation': missing_matrix.corr().to_dict()  # Missing value correlation
+         }
+
+         logger.debug(f"Found {len(row_pattern_counts)} unique row missing patterns")
+         logger.debug(f"Found {len(col_pattern_counts)} unique column missing patterns")
+
+     def _plot_missing_values(
+         self,
+         data: pd.DataFrame,
+         missing_df: pd.DataFrame
+     ) -> None:
+         """Visualise missing values"""
+         fig, axes = plt.subplots(3, 2, figsize=(16, 12))
+
+         # 1. Missing percentage histogram
+         axes[0, 0].barh(
+             missing_df.index,
+             missing_df['missing_percent']
+         )
+         axes[0, 0].axvline(self.config.missing_threshold, color='red', linestyle='--')
+         axes[0, 0].set_title('Missing Percentage by Column')
+         axes[0, 0].set_xlabel('Missing Percentage (%)')
+         axes[0, 0].set_ylabel('Columns')
+         axes[0, 0].grid(True, alpha=0.3)
+
+         # 2. Missing values heatmap (cap at the first 1000 observations so the plot stays readable)
+         missing_matrix = data.isnull()
+         matrix_to_show = missing_matrix.iloc[:1000] if len(data) > 1000 else missing_matrix
+         axes[0, 1].imshow(
+             matrix_to_show.T,
+             aspect='auto',
+             cmap='binary',
+             interpolation='none'
+         )
+         axes[0, 1].set_title('Missing Values Matrix')
+         axes[0, 1].set_xlabel('Observation Index')
+         axes[0, 1].set_ylabel('Variables')
+         axes[0, 1].set_yticks(range(len(data.columns)))
+         axes[0, 1].set_yticklabels(data.columns, fontsize=8)
+
+         # 3. Missing values over time (if time series)
+         if isinstance(data.index, pd.DatetimeIndex):
+             time_missing = data.isnull().resample('M').sum()
+
+             axes[1, 0].plot(time_missing.sum(axis=1))
+             axes[1, 0].set_title('Missing Values by Month')
+             axes[1, 0].set_xlabel('Date')
+             axes[1, 0].set_ylabel('Number of Missing Values')
+             axes[1, 0].grid(True, alpha=0.3)
+
+             # 4. Missing values by day of week
+             data_with_dow = data.copy()
+             data_with_dow['dayofweek'] = data.index.dayofweek
+             dow_missing = data_with_dow.groupby('dayofweek').apply(lambda x: x.isnull().sum().sum())
+             dow_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+
+             axes[1, 1].bar(range(7), dow_missing)
+             axes[1, 1].set_title('Missing Values by Day of Week')
+             axes[1, 1].set_xlabel('Day of Week')
+             axes[1, 1].set_ylabel('Number of Missing Values')
+             axes[1, 1].set_xticks(range(7))
+             axes[1, 1].set_xticklabels(dow_names)
+             axes[1, 1].grid(True, alpha=0.3)
+
+         # 5. Missing value correlation
+         missing_corr = data.isnull().corr()
+         im = axes[2, 0].imshow(
+             missing_corr,
+             cmap='coolwarm',
+             vmin=-1,
+             vmax=1,
+             aspect='auto'
+         )
+         axes[2, 0].set_title('Missing Value Correlation Between Variables')
+         axes[2, 0].set_xlabel('Variables')
+         axes[2, 0].set_ylabel('Variables')
+         plt.colorbar(im, ax=axes[2, 0])
+
+         # 6. Cumulative missing sum
+         cumulative_missing = data.isnull().cumsum()
+         for col in data.columns[:5]:  # First 5 columns
+             if data[col].isnull().any():
+                 axes[2, 1].plot(
+                     cumulative_missing.index,
+                     cumulative_missing[col],
+                     label=col[:20]
+                 )
+         axes[2, 1].set_title('Cumulative Missing Values')
+         axes[2, 1].set_xlabel('Time/Index')
+         axes[2, 1].set_ylabel('Cumulative Missing')
+         axes[2, 1].legend(fontsize=8)
+         axes[2, 1].grid(True, alpha=0.3)
+
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/missing_values_analysis.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+
+     def _log_missing_summary(self, missing_df: pd.DataFrame) -> None:
+         """Log missing value summary"""
+         missing_columns = missing_df[missing_df['missing_count'] > 0]
+
+         if len(missing_columns) > 0:
+             logger.info("MISSING VALUES FOUND:")
+             logger.info("-" * 50)
+             logger.info(f"Total missing values: {self.missing_info['overall']['total_missing']}")
+             logger.info(f"Overall missing percentage: {self.missing_info['overall']['overall_missing_percentage']:.2f}%")
+             logger.info(f"Rows with missing values: {self.missing_info['overall']['rows_with_any_missing']}")
+             logger.info(f"Columns with missing values: {len(self.missing_info['overall']['columns_with_missing'])}")
+
+             logger.info("\nTop-10 columns by missing values:")
+             top_missing = missing_df.nlargest(10, 'missing_percent')
+             for idx, (col, row) in enumerate(top_missing.iterrows(), 1):
+                 logger.info(f"  {idx:2d}. {col}: {int(row['missing_count'])} missing ({row['missing_percent']:.2f}%)")
+         else:
+             logger.info("✓ No missing values found")
+
+     def handle(
+         self,
+         data: pd.DataFrame,
+         method: str = 'interpolate',
+         strategy: str = 'columnwise',
+         **kwargs
+     ) -> pd.DataFrame:
+         """
+         Handle missing values
+
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         method : str
+             Handling method: 'interpolate', 'time_weighted', 'seasonal', 'ffill',
+             'bfill', 'mean', 'median', 'mode', 'knn', 'regression', 'spline', 'stl'
+         strategy : str
+             Strategy: 'columnwise', 'rowwise', 'global'
+         **kwargs : dict
+             Additional parameters for the method
+
+         Returns:
+         --------
+         pd.DataFrame
+             Data with handled missing values
+         """
+         logger.info("\n" + "="*80)
+         logger.info("HANDLING MISSING VALUES")
+         logger.info("="*80)
+
+         data_processed = data.copy()
+         methods_applied = {}
+
+         # Determine columns to process
+         if strategy == 'columnwise':
+             columns_to_process = data_processed.columns
+         elif strategy == 'rowwise':
+             # Row-wise handling (for time series)
+             data_processed = self._handle_rowwise(data_processed, method, **kwargs)
+             return data_processed
+         else:
+             columns_to_process = data_processed.select_dtypes(include=[np.number]).columns
+
+         # Process each column
+         for col in columns_to_process:
+             missing_before = data_processed[col].isnull().sum()
+             if missing_before == 0:
+                 continue  # nothing to impute for this column
+
+             # Check if the missing percentage exceeds the threshold
+             missing_percent = (missing_before / len(data_processed)) * 100
+
+             if missing_percent > self.config.missing_threshold:
+                 logger.warning(f"  {col}: {missing_before} missing ({missing_percent:.1f}%) > threshold {self.config.missing_threshold}%")
+
+                 if kwargs.get('drop_high_missing', False):
+                     data_processed = data_processed.drop(columns=[col])
+                     method_used = f"dropped (>{self.config.missing_threshold}% missing)"
+                     missing_after = 0
+                 else:
+                     # Use selected method
+                     data_processed[col], method_used = self._apply_imputation_method(
+                         data_processed[col], method, **kwargs
+                     )
+                     missing_after = data_processed[col].isnull().sum()
+             else:
+                 # Use selected method
+                 data_processed[col], method_used = self._apply_imputation_method(
+                     data_processed[col], method, **kwargs
+                 )
+                 missing_after = data_processed[col].isnull().sum()
+
+             methods_applied[col] = {
+                 'method': method_used,
+                 'missing_before': int(missing_before),
+                 'missing_after': int(missing_after),
+                 'missing_percent_before': float(missing_percent)
+             }
+
+             logger.info(f"  {col}: {missing_before} → {missing_after} missing ({method_used})")
+
+         self.handling_methods = methods_applied
+
+         # Check that all missing values are handled
+         remaining_missing = data_processed.isnull().sum().sum()
+         if remaining_missing == 0:
+             logger.info("✓ All missing values successfully handled")
+         else:
+             logger.warning(f"⚠ {remaining_missing} missing values remain")
+             # Additional handling of remaining missing values
+             data_processed = data_processed.ffill().bfill()
+             remaining_after = data_processed.isnull().sum().sum()
+             if remaining_after == 0:
+                 logger.info("✓ Remaining missing values handled with ffill/bfill combination")
+
+         return data_processed
+
+     def _apply_imputation_method(
+         self,
+         series: pd.Series,
+         method: str,
+         **kwargs
+     ) -> Tuple[pd.Series, str]:
+         """
+         Apply an imputation method to an individual series
+
+         Parameters:
+         -----------
+         series : pd.Series
+             Input series
+         method : str
+             Imputation method
+         **kwargs : dict
+             Additional parameters
+
+         Returns:
+         --------
+         Tuple[pd.Series, str]
+             Processed series and method description
+         """
+         if method == 'interpolate':
+             # Interpolation for time series
+             if isinstance(series.index, pd.DatetimeIndex):
+                 method_name = f"{kwargs.get('interpolation_method', 'linear')} interpolation"
+                 series_filled = series.interpolate(
+                     method=kwargs.get('interpolation_method', 'linear'),
+                     limit_direction=kwargs.get('limit_direction', 'both'),
+                     limit=kwargs.get('limit', None)
+                 )
+             else:
+                 method_name = 'linear interpolation'
+                 series_filled = series.interpolate(method='linear')
+
+         elif method == 'time_weighted':
+             # Time-weighted interpolation
+             method_name = 'time-weighted interpolation'
+             series_filled = self._time_weighted_interpolation(series)
+
+         elif method == 'seasonal':
+             # Seasonal interpolation
+             method_name = 'seasonal interpolation'
+             series_filled = self._seasonal_interpolation(series, **kwargs)
+
+         elif method == 'ffill':
+             # Forward fill
+             method_name = 'forward fill'
+             series_filled = series.ffill(limit=kwargs.get('limit', None))
+
+         elif method == 'bfill':
+             # Backward fill
+             method_name = 'backward fill'
+             series_filled = series.bfill(limit=kwargs.get('limit', None))
+
+         elif method == 'mean':
+             # Mean imputation
+             method_name = 'mean imputation'
+             series_filled = series.fillna(series.mean())
+
+         elif method == 'median':
+             # Median imputation
+             method_name = 'median imputation'
+             series_filled = series.fillna(series.median())
+
+         elif method == 'mode':
+             # Mode imputation
+             method_name = 'mode imputation'
+             mode_value = series.mode()
+             if not mode_value.empty:
+                 series_filled = series.fillna(mode_value.iloc[0])
+             else:
+                 series_filled = series.fillna(series.median())
+
+         elif method == 'knn':
+             # KNN imputation
+             method_name = f"KNN imputation (k={kwargs.get('k', 5)})"
+             # Simplified version using a distance-weighted neighbour mean
+             series_filled = self._knn_imputation(series, k=kwargs.get('k', 5))
+
+         elif method == 'regression':
+             # Regression imputation
+             method_name = 'regression imputation'
+             series_filled = self._regression_imputation(series, **kwargs)
+
+         elif method == 'spline':
+             # Spline interpolation
+             method_name = 'spline interpolation'
+             series_filled = series.interpolate(method='spline', order=kwargs.get('order', 3))
+
+         elif method == 'stl':
+             # STL decomposition + interpolation
+             method_name = 'STL-based imputation'
+             series_filled = self._stl_imputation(series, **kwargs)
+
+         else:
+             raise ValueError(f"Unknown method: {method}")
+
+         # If missing values remain, fill with ffill/bfill
+         if series_filled.isnull().any():
+             series_filled = series_filled.ffill().bfill()
+             method_name += " + ffill/bfill"
+
+         return series_filled, method_name
+
+     def _time_weighted_interpolation(self, series: pd.Series) -> pd.Series:
+         """Time-weighted interpolation"""
+         if not isinstance(series.index, pd.DatetimeIndex):
+             return series.interpolate()
+
+         # Create timestamps
+         time_numeric = pd.Series(range(len(series)), index=series.index)
+
+         # Interpolate timestamps for missing values
+         time_interpolated = time_numeric.interpolate()
+
+         # Interpolate values based on timestamps
+         valid_mask = series.notna()
+         if valid_mask.sum() < 2:
+             return series.ffill().bfill()
+
+         # Use linear interpolation
+         valid_times = time_numeric[valid_mask]
+         valid_values = series[valid_mask]
+
+         # Interpolation
+         interp_func = interp1d(
+             valid_times,
+             valid_values,
+             kind='linear',
+             bounds_error=False,
+             fill_value='extrapolate'
+         )
+
+         series_filled = series.copy()
+         missing_mask = series.isna()
+         series_filled[missing_mask] = interp_func(time_interpolated[missing_mask])
+
+         return series_filled
+
+     def _seasonal_interpolation(
+         self,
+         series: pd.Series,
+         **kwargs
+     ) -> pd.Series:
+         """Seasonal interpolation"""
+         if not isinstance(series.index, pd.DatetimeIndex):
+             return series.interpolate()
+
+         period = kwargs.get('period', self.config.seasonal_period)
+
+         # Create a copy of the series
+         series_filled = series.copy()
+
+         # Interpolation considering seasonality
+         for i in range(len(series)):
+             if pd.isna(series.iloc[i]):
+                 # Find values at the same seasonal position
+                 seasonal_indices = []
+                 for offset in range(1, 10):  # Look in previous/next cycles
+                     idx_back = i - offset * period
+                     idx_forward = i + offset * period
+
+                     if idx_back >= 0 and not pd.isna(series.iloc[idx_back]):
+                         seasonal_indices.append(idx_back)
+
+                     if idx_forward < len(series) and not pd.isna(series.iloc[idx_forward]):
+                         seasonal_indices.append(idx_forward)
+
+                 if seasonal_indices:
+                     # Take the mean value from the seasonal positions
+                     seasonal_values = series.iloc[seasonal_indices]
+                     series_filled.iloc[i] = seasonal_values.mean()
+
+         # Fill remaining missing values with regular interpolation
+         series_filled = series_filled.interpolate()
+
+         return series_filled
+
+     def _knn_imputation(
+         self,
+         series: pd.Series,
+         k: int = 5
+     ) -> pd.Series:
+         """KNN imputation for time series"""
+         # Simplified KNN for time series: neighbours are the nearest positions in time
+         series_filled = series.copy()
+
+         for i in range(len(series)):
+             if pd.isna(series.iloc[i]):
+                 # Find the nearest k non-missing values
+                 distances = []
+                 values = []
+
+                 for j in range(max(0, i - k * 10), min(len(series), i + k * 10)):
+                     if j != i and not pd.isna(series.iloc[j]):
+                         distance = abs(i - j)
+                         distances.append(distance)
+                         values.append(series.iloc[j])
+
+                     if len(values) >= k:
+                         break
+
+                 if values:
+                     # Distance-weighted average
+                     weights = [1 / (d + 1) for d in distances]
+                     weighted_avg = np.average(values, weights=weights)
+                     series_filled.iloc[i] = weighted_avg
+
+         return series_filled
+
+     def _regression_imputation(
+         self,
+         series: pd.Series,
+         **kwargs
+     ) -> pd.Series:
+         """Regression imputation based on neighbouring values"""
+         # Simplified regression for time series
+         series_filled = series.copy()
+
+         if series.notna().sum() < 3:
+             return series.ffill().bfill()
+
+         # Use polynomial regression
+         x = np.arange(len(series))
+         y = series.values
+
+         # Valid values mask
+         valid_mask = ~np.isnan(y)
+
+         if valid_mask.sum() < 2:
+             return series.ffill().bfill()
+
+         # Polynomial regression, degree 2
+         coeffs = np.polyfit(x[valid_mask], y[valid_mask], 2)
+         poly_func = np.poly1d(coeffs)
+
+         # Fill missing values
+         missing_mask = np.isnan(y)
+         series_filled.iloc[missing_mask] = poly_func(x[missing_mask])
+
+         return series_filled
+
+     def _stl_imputation(
+         self,
+         series: pd.Series,
+         **kwargs
+     ) -> pd.Series:
+         """STL decomposition-based imputation"""
+         try:
+             if not isinstance(series.index, pd.DatetimeIndex):
+                 return series.interpolate()
+
+             # STL decomposition
+             stl = STL(
+                 series.ffill().bfill(),  # Fill missing for STL
+                 period=kwargs.get('period', self.config.seasonal_period),
+                 robust=True
+             )
+             result = stl.fit()
+
+             # Reconstruct the series without noise
+             reconstructed = result.trend + result.seasonal
+
+             # Replace missing values with reconstructed values
+             series_filled = series.copy()
+             missing_mask = series.isna()
+             series_filled[missing_mask] = reconstructed[missing_mask]
+
+             return series_filled
+
+         except Exception as e:
+             logger.warning(f"STL imputation failed: {e}, using interpolation")
+             return series.interpolate()
+
+     def _handle_rowwise(
+         self,
+         data: pd.DataFrame,
+         method: str,
+         **kwargs
+     ) -> pd.DataFrame:
+         """Row-wise missing value handling"""
+         data_processed = data.copy()
+
+         # Remove rows with high missing counts
+         if kwargs.get('drop_rows_threshold', 0) > 0:
+             threshold = kwargs['drop_rows_threshold']
+             rows_before = len(data_processed)
+             missing_per_row = data_processed.isnull().sum(axis=1) / data_processed.shape[1] * 100
+             rows_to_drop = missing_per_row[missing_per_row > threshold].index
+             data_processed = data_processed.drop(rows_to_drop)
+             rows_after = len(data_processed)
+             logger.info(f"Rows removed: {rows_before - rows_after} (missing > {threshold}%)")
+
+         # Row-wise imputation (transpose so fillna aligns per-row statistics on what are now columns)
+         if method == 'row_mean':
+             data_processed = data_processed.T.fillna(data_processed.mean(axis=1)).T
+         elif method == 'row_median':
+             data_processed = data_processed.T.fillna(data_processed.median(axis=1)).T
+         elif method == 'row_ffill':
+             data_processed = data_processed.ffill(axis=1).bfill(axis=1)
+
+         return data_processed
+
+     def create_validation_rules(self) -> Dict:
+         """Create validation rules based on missing value analysis"""
+         rules = {}
+
+         for col, info in self.missing_info['summary'].items():
+             missing_percent = info['missing_percent']
+
+             if missing_percent > 50:
+                 rules[col] = {
+                     'action': 'drop_column',
+                     'reason': f'Missing > 50%: {missing_percent:.1f}%'
+                 }
+             elif missing_percent > 20:
+                 rules[col] = {
+                     'action': 'advanced_imputation',
+                     'reason': f'High missing: {missing_percent:.1f}%',
+                     'recommended_method': 'knn'
+                 }
+             elif missing_percent > 5:
+                 rules[col] = {
+                     'action': 'standard_imputation',
+                     'reason': f'Moderate missing: {missing_percent:.1f}%',
+                     'recommended_method': 'interpolate'
+                 }
+             elif missing_percent > 0:
+                 rules[col] = {
+                     'action': 'simple_imputation',
+                     'reason': f'Low missing: {missing_percent:.1f}%',
+                     'recommended_method': 'ffill'
+                 }
+
+         return rules
+
+     def get_report(self) -> Dict:
+         """Get missing values report"""
+         return {
+             'missing_info': self.missing_info,
+             'handling_methods': self.handling_methods,
+             'missing_patterns': self.missing_patterns,
+             'validation_rules': self.create_validation_rules()
+         }
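
Similarly, a minimal analyse-then-impute sketch. The stub config is an assumption covering only the attributes this class touches (`missing_threshold`, `seasonal_period`, `save_plots`, `results_dir`):

```python
from dataclasses import dataclass

import numpy as np
import pandas as pd

from missing_values.missing_analyzer import MissingValueAnalyser


@dataclass
class StubConfig:
    missing_threshold: float = 30.0   # percent; columns above this get flagged
    seasonal_period: int = 365
    save_plots: bool = False          # skip the matplotlib figure in this sketch
    results_dir: str = "streamlit_results"


# Two years of daily temperatures with 50 holes poked at random positions
dates = pd.date_range("2020-01-01", periods=730, freq="D")
rng = np.random.default_rng(1)
series = pd.Series(20 + rng.normal(0, 3, len(dates)), index=dates, name="tavg")
series.iloc[rng.choice(len(series), 50, replace=False)] = np.nan

analyser = MissingValueAnalyser(StubConfig())
info = analyser.analyse(series.to_frame(), detailed=False)
filled = analyser.handle(series.to_frame(), method="interpolate")
assert filled.isnull().sum().sum() == 0  # interpolate + ffill/bfill closes all gaps
```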
outliers/__init__.py ADDED
File without changes
outliers/outlier_analyzer.py ADDED
@@ -0,0 +1,857 @@
1
+ # ============================================
2
+ # CLASS 4: OUTLIER ANALYSER
3
+ # ============================================
4
+ from typing import Dict, List, Tuple
5
+ from venv import logger
6
+
7
+ from config.config import Config
8
+ import pandas as pd
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ from sklearn.neighbors import LocalOutlierFactor
12
+ from sklearn.covariance import EllipticEnvelope
13
+ from scipy import stats
14
+
15
+ class OutlierAnalyser:
16
+ """Class for analysing and handling outliers"""
17
+
18
+ def __init__(self, config: Config):
19
+ """
20
+ Initialise outlier analyser
21
+
22
+ Parameters:
23
+ -----------
24
+ config : Config
25
+ Experiment configuration
26
+ """
27
+ self.config = config
28
+ self.outlier_info = {}
29
+ self.handling_methods = {}
30
+ self.detection_methods = {}
31
+ self.outlier_models = {}
32
+
33
+ def analyse(
34
+ self,
35
+ data: pd.DataFrame,
36
+ method: str = None,
37
+ columns: List[str] = None,
38
+ **kwargs
39
+ ) -> Dict:
40
+ """
41
+ Analyse outliers in data
42
+
43
+ Parameters:
44
+ -----------
45
+ data : pd.DataFrame
46
+ Input data
47
+ method : str, optional
48
+ Detection method. If None, uses configuration value.
49
+ columns : List[str], optional
50
+ List of columns to analyse. If None, uses all numeric columns.
51
+ **kwargs : dict
52
+ Additional parameters for method
53
+
54
+ Returns:
55
+ --------
56
+ Dict
57
+ Information about outliers
58
+ """
59
+ logger.info("\n" + "="*80)
60
+ logger.info("OUTLIER ANALYSIS")
61
+ logger.info("="*80)
62
+
63
+ method = method or self.config.outlier_method
64
+ if columns is None:
65
+ columns = data.select_dtypes(include=[np.number]).columns
66
+
67
+ outliers_info = {}
68
+
69
+ # Apply various detection methods
70
+ detection_results = {}
71
+
72
+ # 1. Statistical methods
73
+ if method in ['iqr', 'zscore', 'sigma', 'all']:
74
+ detection_results.update(self._statistical_methods(data, columns, method, **kwargs))
75
+
76
+ # 2. ML methods
77
+ if method in ['lof', 'isolation_forest', 'elliptic_envelope', 'all']:
78
+ detection_results.update(self._ml_methods(data, columns, method, **kwargs))
79
+
80
+ # 3. Temporal methods
81
+ if isinstance(data.index, pd.DatetimeIndex):
82
+ detection_results.update(self._temporal_methods(data, columns, **kwargs))
83
+
84
+ # Aggregate results
85
+ for col in columns:
86
+ if col in detection_results:
87
+ # Combine results from different methods
88
+ combined_mask = self._combine_detection_methods(detection_results, col)
89
+
90
+ outliers_count = combined_mask.sum()
91
+ outliers_percent = (outliers_count / len(data)) * 100
92
+
93
+ # Detailed information
94
+ col_data = data[col].dropna()
95
+ stats = {
96
+ 'mean': float(col_data.mean()),
97
+ 'std': float(col_data.std()),
98
+ 'median': float(col_data.median()),
99
+ 'q1': float(col_data.quantile(0.25)),
100
+ 'q3': float(col_data.quantile(0.75)),
101
+ 'min': float(col_data.min()),
102
+ 'max': float(col_data.max()),
103
+ 'skewness': float(col_data.skew()),
104
+ 'kurtosis': float(col_data.kurtosis())
105
+ }
106
+
107
+ outliers_info[col] = {
108
+ 'method': method,
109
+ 'statistics': stats,
110
+ 'outliers_count': int(outliers_count),
111
+ 'outliers_percent': float(outliers_percent),
112
+ 'outlier_indices': data[combined_mask].index.tolist() if outliers_count > 0 else [],
113
+ 'outlier_values': data.loc[combined_mask, col].tolist() if outliers_count > 0 else [],
114
+ 'detection_methods': {
115
+ name: {
116
+ 'count': int(mask.sum()),
117
+ 'percent': float(mask.sum() / len(data) * 100)
118
+ }
119
+ for name, mask in detection_results[col].items()
120
+ }
121
+ }
122
+
123
+ logger.info(f"{col}: {outliers_count} outliers ({outliers_percent:.2f}%)")
124
+
125
+ self.outlier_info = outliers_info
126
+ self.detection_methods = detection_results
127
+
128
+ # Visualisation
129
+ if self.config.save_plots and len(columns) > 0:
130
+ self._plot_outlier_analysis(data, columns, outliers_info)
131
+
132
+ return outliers_info
133
+
134
+ def _statistical_methods(
135
+ self,
136
+ data: pd.DataFrame,
137
+ columns: List[str],
138
+ method: str,
139
+ **kwargs
140
+ ) -> Dict:
141
+ """Statistical outlier detection methods"""
142
+ results = {}
143
+
144
+ for col in columns:
145
+ col_results = {}
146
+ series = data[col].dropna()
147
+
148
+ if len(series) < 3:
149
+ continue
150
+
151
+ # IQR method
152
+ if method in ['iqr', 'all']:
153
+ q1 = series.quantile(0.25)
154
+ q3 = series.quantile(0.75)
155
+ iqr = q3 - q1
156
+ lower_bound = q1 - self.config.outlier_alpha * iqr
157
+ upper_bound = q3 + self.config.outlier_alpha * iqr
158
+
159
+ iqr_mask = (data[col] < lower_bound) | (data[col] > upper_bound)
160
+ col_results['iqr'] = iqr_mask
161
+
162
+ # Z-score method
163
+ if method in ['zscore', 'sigma', 'all']:
164
+ z_threshold = kwargs.get('z_threshold', 3)
165
+ z_scores = np.abs((data[col] - series.mean()) / series.std())
166
+ z_mask = z_scores > z_threshold
167
+ col_results['zscore'] = z_mask
168
+
169
+ # Modified Z-score method
170
+ if method in ['zscore', 'all']:
171
+ median = series.median()
172
+ mad = np.median(np.abs(series - median))
173
+ if mad != 0:
174
+ modified_z_scores = 0.6745 * (data[col] - median) / mad
175
+ mz_mask = np.abs(modified_z_scores) > 3.5
176
+ col_results['modified_zscore'] = mz_mask
177
+
178
+ # Tukey's fences
179
+ if method in ['iqr', 'all']:
180
+ inner_lower = q1 - 1.5 * iqr
181
+ inner_upper = q3 + 1.5 * iqr
182
+ outer_lower = q1 - 3 * iqr
183
+ outer_upper = q3 + 3 * iqr
184
+
185
+ mild_mask = ((data[col] < inner_lower) | (data[col] > inner_upper)) & \
186
+ ((data[col] >= outer_lower) & (data[col] <= outer_upper))
187
+ extreme_mask = (data[col] < outer_lower) | (data[col] > outer_upper)
188
+
189
+ col_results['tukey_mild'] = mild_mask
190
+ col_results['tukey_extreme'] = extreme_mask
191
+
192
+ results[col] = col_results
193
+
194
+ return results
195
+
196
+ def _ml_methods(
197
+ self,
198
+ data: pd.DataFrame,
199
+ columns: List[str],
200
+ method: str,
201
+ **kwargs
202
+ ) -> Dict:
203
+ """ML outlier detection methods"""
204
+ results = {}
205
+
206
+ numeric_data = data[columns].dropna()
207
+
208
+ if len(numeric_data) < 10:
209
+ return results
210
+
211
+ try:
212
+ # Local Outlier Factor
213
+ if method in ['lof', 'all']:
214
+ lof = LocalOutlierFactor(
215
+ contamination=self.config.outlier_contamination,
216
+ n_neighbors=kwargs.get('n_neighbors', 20)
217
+ )
218
+ lof_labels = lof.fit_predict(numeric_data)
219
+ lof_mask = pd.Series(lof_labels == -1, index=numeric_data.index)
220
+
221
+ for col in columns:
222
+ if col in numeric_data.columns:
223
+ if col not in results:
224
+ results[col] = {}
225
+ results[col]['lof'] = lof_mask
226
+
227
+ # Elliptic Envelope
228
+ if method in ['elliptic_envelope', 'all']:
229
+ try:
230
+ envelope = EllipticEnvelope(
231
+ contamination=self.config.outlier_contamination,
232
+ random_state=42
233
+ )
234
+ envelope_labels = envelope.fit_predict(numeric_data)
235
+ envelope_mask = pd.Series(envelope_labels == -1, index=numeric_data.index)
236
+
237
+ for col in columns:
238
+ if col in numeric_data.columns:
239
+ if col not in results:
240
+ results[col] = {}
241
+ results[col]['elliptic_envelope'] = envelope_mask
242
+ except Exception as e:
243
+ logger.warning(f"Elliptic Envelope failed: {e}")
244
+
245
+ except Exception as e:
246
+ logger.warning(f"ML outlier detection methods failed: {e}")
247
+
248
+ return results
249
+
250
+ def _temporal_methods(
251
+ self,
252
+ data: pd.DataFrame,
253
+ columns: List[str],
254
+ **kwargs
255
+ ) -> Dict:
256
+ """Outlier detection methods for time series"""
257
+ results = {}
258
+
259
+ for col in columns:
260
+ col_results = {}
261
+ series = data[col].dropna()
262
+
263
+ if len(series) < 30:
264
+ continue
265
+
266
+ # Rolling statistics method
267
+ window = kwargs.get('temporal_window', 30)
268
+ rolling_mean = series.rolling(window=window, center=True).mean()
269
+ rolling_std = series.rolling(window=window, center=True).std()
270
+
271
+ # Outliers relative to moving average
272
+ threshold = kwargs.get('temporal_threshold', 3)
273
+ temporal_mask = np.abs(series - rolling_mean) > (threshold * rolling_std)
274
+ col_results['temporal'] = temporal_mask
275
+
276
+ # Seasonal detrending + outlier detection
277
+ try:
278
+ # Simple seasonal detrending
279
+ if len(series) > 365:
280
+ seasonal_period = kwargs.get('seasonal_period', 365)
281
+ seasonal_mean = series.rolling(window=seasonal_period, center=True).mean()
282
+ detrended = series - seasonal_mean
283
+
284
+ # Outliers in detrended series
285
+ q1 = detrended.quantile(0.25)
286
+ q3 = detrended.quantile(0.75)
287
+ iqr = q3 - q1
288
+ seasonal_lower = q1 - 3 * iqr
289
+ seasonal_upper = q3 + 3 * iqr
290
+
291
+ seasonal_mask = (detrended < seasonal_lower) | (detrended > seasonal_upper)
292
+ col_results['seasonal'] = seasonal_mask
293
+ except Exception as e:
294
+ logger.debug(f"Seasonal outlier detection failed for {col}: {e}")
295
+
296
+ results[col] = col_results
297
+
298
+ return results
299
+
300
+ def _combine_detection_methods(
301
+ self,
302
+ detection_results: Dict,
303
+ column: str
304
+ ) -> pd.Series:
305
+ """Combine results from different detection methods"""
306
+ if column not in detection_results:
307
+ return pd.Series(False, index=pd.RangeIndex(0))
308
+
309
+ methods = detection_results[column]
310
+ combined_mask = None
311
+
312
+ for method_name, mask in methods.items():
313
+ if combined_mask is None:
314
+ combined_mask = mask.copy()
315
+ else:
316
+ # Combine via OR (outlier by any method)
317
+ combined_mask = combined_mask | mask
318
+
319
+ if combined_mask is None:
+ return pd.Series(False, index=pd.RangeIndex(0))
+ return combined_mask.fillna(False)
320
+
321
+ def _plot_outlier_analysis(
322
+ self,
323
+ data: pd.DataFrame,
324
+ columns: List[str],
325
+ outliers_info: Dict
326
+ ) -> None:
327
+ """Visualise outlier analysis"""
328
+ n_cols = min(len(columns), 4)
329
+ n_rows = (len(columns) + n_cols - 1) // n_cols
330
+
331
+ fig = plt.figure(figsize=(16, 4 * n_rows))
332
+ gs = fig.add_gridspec(n_rows, n_cols)
333
+
334
+ for idx, col in enumerate(columns):
335
+ if col not in outliers_info:
336
+ continue
337
+
338
+ row = idx // n_cols
339
+ col_idx = idx % n_cols
340
+
341
+ ax = fig.add_subplot(gs[row, col_idx])
342
+
343
+ # Data
344
+ series = data[col].dropna()
345
+
346
+ # 1. Box plot
347
+ bp = ax.boxplot(
348
+ series.values,
349
+ vert=True,
350
+ patch_artist=True,
351
+ widths=0.6,
352
+ showfliers=False
353
+ )
354
+
355
+ # Colours for box plot
356
+ bp['boxes'][0].set_facecolor('lightblue')
357
+ bp['medians'][0].set_color('red')
358
+ bp['whiskers'][0].set_color('black')
359
+ bp['whiskers'][1].set_color('black')
360
+ bp['caps'][0].set_color('black')
361
+ bp['caps'][1].set_color('black')
362
+
363
+ # 2. Outliers
364
+ if outliers_info[col]['outliers_count'] > 0:
365
+ outlier_indices = outliers_info[col]['outlier_indices']
366
+ outlier_values = outliers_info[col]['outlier_values']
367
+
368
+ # Convert indices to positions for box plot
369
+ jitter = np.random.normal(0, 0.05, len(outlier_values))
370
+
371
+ ax.scatter(
372
+ np.ones(len(outlier_values)) + jitter,
373
+ outlier_values,
374
+ color='red',
375
+ alpha=0.6,
376
+ s=30,
377
+ edgecolors='black',
378
+ label=f'Outliers ({outliers_info[col]["outliers_count"]})'
379
+ )
380
+
381
+ # 3. Histogram on same plot
382
+ ax2 = ax.twinx()
383
+ ax2.hist(
384
+ series.values,
385
+ bins=30,
386
+ alpha=0.3,
387
+ color='green',
388
+ density=True
389
+ )
390
+
391
+ # 4. Normal distribution for comparison
392
+ if len(series) > 10:
394
+ x = np.linspace(series.min(), series.max(), 100)
395
+ mean = series.mean()
396
+ std = series.std()
397
+
398
+ if std > 0:
399
+ p = stats.norm.pdf(x, mean, std)
400
+ ax2.plot(x, p, 'k--', linewidth=1, label='Normal distribution')
401
+
402
+ ax.set_title(f'{col}\nOutliers: {outliers_info[col]["outliers_count"]} ({outliers_info[col]["outliers_percent"]:.1f}%)')
403
+ ax.set_ylabel('Value')
404
+ ax.grid(True, alpha=0.3)
405
+
406
+ # Legend
407
+ if outliers_info[col]['outliers_count'] > 0:
408
+ ax.legend(loc='upper right', fontsize=8)
409
+ ax2.legend(loc='upper left', fontsize=8)
410
+
411
+ plt.tight_layout()
412
+ plt.savefig(
413
+ f'{self.config.results_dir}/plots/outliers_analysis.png',
414
+ dpi=300,
415
+ bbox_inches='tight'
416
+ )
417
+ plt.show()
418
+
419
+ # Additional plots for time series
420
+ if isinstance(data.index, pd.DatetimeIndex) and len(columns) > 0:
421
+ self._plot_temporal_outliers(data, columns, outliers_info)
422
+
423
+ def _plot_temporal_outliers(
424
+ self,
425
+ data: pd.DataFrame,
426
+ columns: List[str],
427
+ outliers_info: Dict
428
+ ) -> None:
429
+ """Visualise outliers over time"""
430
+ n_plots = min(len(columns), 3)
431
+
432
+ fig, axes = plt.subplots(n_plots, 1, figsize=(14, 4 * n_plots))
433
+ if n_plots == 1:
434
+ axes = [axes]
435
+
436
+ for idx, (col, ax) in enumerate(zip(columns[:n_plots], axes)):
437
+ if col not in outliers_info:
438
+ continue
439
+
440
+ # Time series
441
+ ax.plot(data.index, data[col], alpha=0.7, linewidth=1, label='Original series')
442
+
443
+ # Outliers
444
+ if outliers_info[col]['outliers_count'] > 0:
445
+ outlier_indices = outliers_info[col]['outlier_indices']
446
+ outlier_values = outliers_info[col]['outlier_values']
447
+
448
+ ax.scatter(
449
+ outlier_indices,
450
+ outlier_values,
451
+ color='red',
452
+ s=40,
453
+ edgecolors='black',
454
+ zorder=5,
455
+ label='Outliers'
456
+ )
457
+
458
+ # Moving average
459
+ if len(data) > 30:
460
+ rolling_mean = data[col].rolling(window=30, center=True).mean()
461
+ ax.plot(data.index, rolling_mean, 'orange', linewidth=2, label='Moving average (30)')
462
+
463
+ ax.set_title(f'Outliers over time: {col}')
464
+ ax.set_xlabel('Date')
465
+ ax.set_ylabel(col)
466
+ ax.legend(fontsize=8)
467
+ ax.grid(True, alpha=0.3)
468
+
469
+ plt.tight_layout()
470
+ plt.savefig(
471
+ f'{self.config.results_dir}/plots/temporal_outliers.png',
472
+ dpi=300,
473
+ bbox_inches='tight'
474
+ )
475
+ plt.show()
476
+
477
+ def handle(
478
+ self,
479
+ data: pd.DataFrame,
480
+ method: str = 'clip',
481
+ strategy: str = 'columnwise',
482
+ **kwargs
483
+ ) -> pd.DataFrame:
484
+ """
485
+ Handle outliers
486
+
487
+ Parameters:
488
+ -----------
489
+ data : pd.DataFrame
490
+ Input data
491
+ method : str
492
+ Handling method: 'clip', 'remove', 'mean', 'median', 'winsorize', 'transform', 'impute', 'adaptive'
493
+ strategy : str
494
+ Strategy: 'columnwise', 'global', 'adaptive'
495
+ **kwargs : dict
496
+ Additional parameters for method
497
+
498
+ Returns:
499
+ --------
500
+ pd.DataFrame
501
+ Data with handled outliers
502
+ """
503
+ logger.info("\n" + "="*80)
504
+ logger.info("HANDLING OUTLIERS")
505
+ logger.info("="*80)
506
+
507
+ if not self.outlier_info:
508
+ logger.warning("⚠ Perform outlier analysis first")
509
+ return data
510
+
511
+ data_processed = data.copy()
512
+ methods_applied = {}
513
+
514
+ for col, info in self.outlier_info.items():
515
+ if col not in data_processed.columns:
516
+ continue
517
+
518
+ outliers_count = info['outliers_count']
519
+
520
+ if outliers_count > 0:
521
+ # Create outlier mask
522
+ outlier_mask = pd.Series(False, index=data_processed.index)
523
+ if info['outlier_indices']:
524
+ outlier_indices = [idx for idx in info['outlier_indices'] if idx in data_processed.index]
525
+ outlier_mask.loc[outlier_indices] = True
526
+
527
+ # Determine boundaries
528
+ col_stats = info['statistics']  # local name; avoids shadowing scipy.stats
529
+ q1, q3 = col_stats['q1'], col_stats['q3']
530
+ iqr = q3 - q1
531
+ lower_bound = q1 - self.config.outlier_alpha * iqr
532
+ upper_bound = q3 + self.config.outlier_alpha * iqr
533
+
534
+ if method == 'clip':
535
+ # Clip values to boundaries
536
+ data_processed[col] = data_processed[col].clip(
537
+ lower=lower_bound,
538
+ upper=upper_bound
539
+ )
540
+ method_used = 'clipping'
541
+ affected = outliers_count
542
+
543
+ elif method == 'remove':
544
+ # Remove rows with outliers
545
+ data_processed = data_processed[~outlier_mask]
546
+ method_used = 'removal'
547
+ affected = outliers_count
548
+
549
+ elif method == 'mean':
550
+ # Replace outliers with mean value
551
+ mean_val = data_processed[col].mean()
552
+ data_processed.loc[outlier_mask, col] = mean_val
553
+ method_used = 'mean imputation'
554
+ affected = outliers_count
555
+
556
+ elif method == 'median':
557
+ # Replace outliers with median
558
+ median_val = data_processed[col].median()
559
+ data_processed.loc[outlier_mask, col] = median_val
560
+ method_used = 'median imputation'
561
+ affected = outliers_count
562
+
563
+ elif method == 'winsorize':
564
+ # Winsorisation
565
+ data_processed[col] = self._winsorize_series(
566
+ data_processed[col],
567
+ limits=kwargs.get('limits', (0.05, 0.05))
568
+ )
569
+ method_used = 'winsorization'
570
+ affected = outliers_count
571
+
572
+ elif method == 'transform':
573
+ # Transformation to reduce outlier impact
574
+ transform_method = kwargs.get('transform_method', 'log')
575
+ data_processed[col] = self._transform_series(
576
+ data_processed[col],
577
+ method=transform_method
578
+ )
579
+ method_used = f'{transform_method} transformation'
580
+ affected = 'all' # Transformation applied to all values
581
+
582
+ elif method == 'impute':
583
+ # Smart outlier imputation
584
+ impute_method = kwargs.get('impute_method', 'neighbors')
585
+ data_processed[col] = self._impute_outliers(
586
+ data_processed[col],
587
+ outlier_mask,
588
+ method=impute_method,
589
+ **kwargs
590
+ )
591
+ method_used = f'{impute_method} imputation'
592
+ affected = outliers_count
593
+
594
+ elif method == 'adaptive':
595
+ # Adaptive handling
596
+ data_processed[col] = self._adaptive_outlier_handling(
597
+ data_processed[col],
598
+ outlier_mask,
599
+ **kwargs
600
+ )
601
+ method_used = 'adaptive handling'
602
+ affected = outliers_count
603
+
604
+ else:
605
+ raise ValueError(f"Unknown method: {method}")
606
+
607
+ methods_applied[col] = {
608
+ 'method': method_used,
609
+ 'outliers_before': outliers_count,
610
+ 'affected': affected,
611
+ 'bounds': {
612
+ 'lower': float(lower_bound),
613
+ 'upper': float(upper_bound)
614
+ }
615
+ }
616
+
617
+ logger.info(f" {col}: {outliers_count} outliers handled ({method_used})")
618
+
619
+ self.handling_methods = methods_applied
620
+
621
+ # Handling statistics
622
+ total_outliers = sum(info['outliers_count'] for info in self.outlier_info.values())
623
+ total_affected = sum(method['affected'] for method in methods_applied.values()
624
+ if isinstance(method['affected'], (int, np.integer)))
625
+
626
+ logger.info(f"\n✓ {total_affected} out of {total_outliers} outliers handled")
627
+ logger.info(f" Data size before: {len(data)} rows")
628
+ logger.info(f" Data size after: {len(data_processed)} rows")
629
+
630
+ # Visualise results
631
+ if self.config.save_plots and methods_applied:
632
+ self._plot_outlier_handling_results(data, data_processed, methods_applied)
633
+
634
+ return data_processed
635
+
636
+ def _winsorize_series(
637
+ self,
638
+ series: pd.Series,
639
+ limits: Tuple[float, float] = (0.05, 0.05)
640
+ ) -> pd.Series:
641
+ """Winsorize series"""
642
+ from scipy.stats.mstats import winsorize
643
+ try:
644
+ winsorized = winsorize(series.values, limits=limits)
645
+ return pd.Series(winsorized, index=series.index)
646
+ except Exception:
647
+ return series
648
+
649
+ def _transform_series(
650
+ self,
651
+ series: pd.Series,
652
+ method: str = 'log'
653
+ ) -> pd.Series:
654
+ """Transform series to reduce outlier impact"""
655
+ series_transformed = series.copy()
656
+
657
+ if method == 'log':
658
+ # Logarithmic transformation
659
+ min_val = series.min()
660
+ if min_val <= 0:
661
+ shift = abs(min_val) + 1
662
+ series_transformed = np.log(series + shift)
663
+ else:
664
+ series_transformed = np.log(series)
665
+
666
+ elif method == 'boxcox':
667
+ # Box-Cox transformation
668
+ try:
669
+ from scipy.stats import boxcox
670
+ transformed, _ = boxcox(series - series.min() + 1)
671
+ series_transformed = pd.Series(transformed, index=series.index)
672
+ except Exception:
673
+ logger.warning("Box-Cox transformation failed, using log")
674
+ return self._transform_series(series, 'log')
675
+
676
+ elif method == 'sqrt':
677
+ # Square root
678
+ min_val = series.min()
679
+ if min_val < 0:
680
+ series_transformed = np.sqrt(series - min_val)
681
+ else:
682
+ series_transformed = np.sqrt(series)
683
+
684
+ elif method == 'yeojohnson':
685
+ # Yeo-Johnson transformation
686
+ try:
687
+ from scipy.stats import yeojohnson
688
+ transformed, _ = yeojohnson(series)
689
+ series_transformed = pd.Series(transformed, index=series.index)
690
+ except Exception:
691
+ logger.warning("Yeo-Johnson transformation failed, using log")
692
+ return self._transform_series(series, 'log')
693
+
694
+ return series_transformed
695
+
696
+ def _impute_outliers(
697
+ self,
698
+ series: pd.Series,
699
+ outlier_mask: pd.Series,
700
+ method: str = 'neighbors',
701
+ **kwargs
702
+ ) -> pd.Series:
703
+ """Smart outlier imputation"""
704
+ series_imputed = series.copy()
705
+
706
+ if method == 'neighbors':
707
+ # Replace with mean of neighbouring values
708
+ for idx in series[outlier_mask].index:
709
+ if idx in series.index:
710
+ pos = series.index.get_loc(idx)
711
+ neighbours = []
712
+
713
+ # Find nearest non-outliers
714
+ for offset in range(1, 6):
715
+ if pos - offset >= 0 and not outlier_mask.iloc[pos - offset]:
716
+ neighbours.append(series.iloc[pos - offset])
717
+ break
718
+
719
+ for offset in range(1, 6):
720
+ if pos + offset < len(series) and not outlier_mask.iloc[pos + offset]:
721
+ neighbours.append(series.iloc[pos + offset])
722
+ break
723
+
724
+ if neighbours:
725
+ series_imputed.loc[idx] = np.mean(neighbours)
726
+
727
+ elif method == 'interpolate':
728
+ # Interpolation
729
+ series_imputed = series.mask(outlier_mask).interpolate()
730
+
731
+ elif method == 'rolling':
732
+ # Replace with moving average
733
+ window = kwargs.get('window', 5)
734
+ rolling_mean = series.rolling(window=window, center=True, min_periods=1).mean()
735
+ series_imputed = series.mask(outlier_mask, rolling_mean)
736
+
737
+ return series_imputed
738
+
739
+ def _adaptive_outlier_handling(
740
+ self,
741
+ series: pd.Series,
742
+ outlier_mask: pd.Series,
743
+ **kwargs
744
+ ) -> pd.Series:
745
+ """Adaptive outlier handling"""
746
+ series_processed = series.copy()
747
+ outlier_indices = series[outlier_mask].index
748
+
749
+ for idx in outlier_indices:
750
+ if idx in series.index:
751
+ value = series.loc[idx]
752
+ col_stats = self.outlier_info.get(series.name, {}).get('statistics', {})
753
+
754
+ # Determine outlier type
755
+ q1 = col_stats.get('q1', series.quantile(0.25))
756
+ q3 = col_stats.get('q3', series.quantile(0.75))
757
+ iqr = q3 - q1
758
+
759
+ if value < q1 - 3 * iqr:
760
+ # Extreme low outlier
761
+ series_processed.loc[idx] = q1 - 1.5 * iqr
762
+ elif value > q3 + 3 * iqr:
763
+ # Extreme high outlier
764
+ series_processed.loc[idx] = q3 + 1.5 * iqr
765
+ else:
766
+ # Moderate outlier
767
+ pos = series.index.get_loc(idx)
768
+ # Use linear interpolation
769
+ if pos > 0 and pos < len(series) - 1:
770
+ series_processed.loc[idx] = (series.iloc[pos-1] + series.iloc[pos+1]) / 2
771
+
772
+ return series_processed
773
+
774
+ def _plot_outlier_handling_results(
775
+ self,
776
+ original_data: pd.DataFrame,
777
+ processed_data: pd.DataFrame,
778
+ methods_applied: Dict
779
+ ) -> None:
780
+ """Visualise outlier handling results"""
781
+ cols_to_plot = list(methods_applied.keys())[:3]
782
+
783
+ if not cols_to_plot:
784
+ return
785
+
786
+ fig, axes = plt.subplots(len(cols_to_plot), 2, figsize=(14, 4 * len(cols_to_plot)))
787
+ if len(cols_to_plot) == 1:
788
+ axes = axes.reshape(1, -1)
789
+
790
+ for idx, col in enumerate(cols_to_plot):
791
+ if col not in original_data.columns or col not in processed_data.columns:
792
+ continue
793
+
794
+ # Distribution before handling
795
+ axes[idx, 0].hist(original_data[col].dropna(), bins=30, alpha=0.5, label='Before', density=True)
796
+ axes[idx, 0].hist(processed_data[col].dropna(), bins=30, alpha=0.5, label='After', density=True)
797
+ axes[idx, 0].set_title(f'{col}: Distribution before/after')
798
+ axes[idx, 0].set_xlabel('Value')
799
+ axes[idx, 0].set_ylabel('Density')
800
+ axes[idx, 0].legend()
801
+ axes[idx, 0].grid(True, alpha=0.3)
802
+
803
+ # QQ plot for normality check
804
+ stats.probplot(original_data[col].dropna(), dist="norm", plot=axes[idx, 1])
805
+ axes[idx, 1].set_title(f'{col}: Q-Q plot (before handling)')
806
+ axes[idx, 1].grid(True, alpha=0.3)
807
+
808
+ plt.tight_layout()
809
+ plt.savefig(
810
+ f'{self.config.results_dir}/plots/outlier_handling_results.png',
811
+ dpi=300,
812
+ bbox_inches='tight'
813
+ )
814
+ plt.show()
815
+
816
+ def create_validation_rules(self) -> Dict:
817
+ """Create validation rules based on outlier analysis"""
818
+ rules = {}
819
+
820
+ for col, info in self.outlier_info.items():
821
+ outliers_percent = info['outliers_percent']
822
+ skewness = info['statistics']['skewness']
823
+
824
+ rule = {
825
+ 'outliers_percent': outliers_percent,
826
+ 'skewness': skewness,
827
+ 'recommended_action': 'none'
828
+ }
829
+
830
+ if outliers_percent > 10:
831
+ rule['recommended_action'] = 'aggressive_handling'
832
+ rule['reason'] = f'High outliers: {outliers_percent:.1f}%'
833
+ elif outliers_percent > 5:
834
+ rule['recommended_action'] = 'moderate_handling'
835
+ rule['reason'] = f'Moderate outliers: {outliers_percent:.1f}%'
836
+ elif outliers_percent > 1:
837
+ rule['recommended_action'] = 'conservative_handling'
838
+ rule['reason'] = f'Low outliers: {outliers_percent:.1f}%'
839
+
840
+ if abs(skewness) > 1:
841
+ rule['skewness_issue'] = True
842
+ rule['skewness_reason'] = f'Strong skewness: {skewness:.2f}'
843
+ if rule['recommended_action'] == 'none':
844
+ rule['recommended_action'] = 'transformation'
845
+
846
+ rules[col] = rule
847
+
848
+ return rules
849
+
850
+ def get_report(self) -> Dict:
851
+ """Get outlier analysis report"""
852
+ return {
853
+ 'outlier_info': self.outlier_info,
854
+ 'handling_methods': self.handling_methods,
855
+ 'detection_methods': self.detection_methods,
856
+ 'validation_rules': self.create_validation_rules()
857
+ }
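
A minimal usage sketch for the analyser above (hedged: the `Config` keyword arguments and the input CSV are illustrative, and `analyse()` must run before `handle()` because `handle()` reads `self.outlier_info`):

```python
import pandas as pd

from config.config import Config
from outliers.outlier_analyzer import OutlierAnalyser

# Illustrative input: any CSV with a 'date' column and numeric series
df = pd.read_csv('temp_data.csv', parse_dates=['date'], index_col='date')

config = Config(outlier_method='iqr', results_dir='results')  # assumed kwargs
analyser = OutlierAnalyser(config)

# Detection first, then handling: handle() relies on self.outlier_info
analyser.analyse(
    df,
    method=config.outlier_method,
    columns=df.select_dtypes(include='number').columns.tolist(),
)
df_clean = analyser.handle(df, method='clip')
print(analyser.get_report()['handling_methods'])
```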
pipeline/__init__.py ADDED
File without changes
pipeline/main_pipeline.py ADDED
@@ -0,0 +1,603 @@
1
+ # ============================================
2
+ # CLASS 14: MAIN PIPELINE
3
+ # ============================================
4
+ from datetime import datetime
5
+ import json
6
+ import os
7
+ import traceback
8
+ from typing import Any, Dict, Optional
9
+ import logging
10
+ from config.config import Config
11
+ from correlations.correlation_analyzer import CorrelationAnalyzer
12
+ from data_loader.data_loader import DataLoader
13
+ from decomposition.decomposer import TimeSeriesDecomposer
14
+ from feature_selection.feature_selector import FeatureSelector
15
+ from features.feature_engineer import FeatureEngineer
16
+
17
+ from missing_values.missing_analyzer import MissingValueAnalyser
18
+ from outliers.outlier_analyzer import OutlierAnalyser
19
+ from scaling.data_scaler import DataScaler
20
+ from splitting.data_splitter import DataSplitter
21
+ from stationarity.stationarity_checker import StationarityChecker
22
+ from validation.data_validator import DataValidator
23
+ import pandas as pd
24
+ import numpy as np
25
+
26
+ from visualization.visualization_manager import VisualisationManager
27
+
+ logger = logging.getLogger(__name__)
28
+ class EnhancedDataPreprocessingPipeline:
29
+ """Enhanced main data preprocessing pipeline"""
30
+
31
+ def __init__(self, config: Config):
32
+ """
33
+ Initialise pipeline
34
+
35
+ Parameters:
36
+ -----------
37
+ config : Config
38
+ Experiment configuration
39
+ """
40
+ self.config = config
41
+ self.data_loader = DataLoader(config)
42
+ self.missing_analyser = MissingValueAnalyser(config)
43
+ self.outlier_analyser = OutlierAnalyser(config)
44
+ self.feature_engineer = FeatureEngineer(config)
45
+ self.stationarity_checker = StationarityChecker(config)
46
+ self.decomposer = TimeSeriesDecomposer(config)
47
+ self.correlation_analyser = CorrelationAnalyzer(config)
48
+ self.data_splitter = DataSplitter(config)
49
+ self.data_scaler = DataScaler(config)
50
+ self.feature_selector = FeatureSelector(config)
51
+ self.data_validator = DataValidator(config)
52
+ self.visualisation_manager = VisualisationManager(config)
53
+
54
+ self.results = {}
55
+ self.processed_data = None
56
+ self.train_data = None
57
+ self.val_data = None
58
+ self.test_data = None
59
+ self.is_fitted = False
60
+
61
+ def run_full_pipeline(
62
+ self,
63
+ data_path: Optional[str] = None,
64
+ use_synthetic: bool = False,
65
+ save_intermediate: bool = True,
66
+ create_reports: bool = True
67
+ ) -> pd.DataFrame:
68
+ """
69
+ Run enhanced full data preprocessing pipeline
70
+
71
+ Parameters:
72
+ -----------
73
+ data_path : str, optional
74
+ Path to data. If None, uses configuration value.
75
+ use_synthetic : bool
76
+ Use synthetic data for testing
77
+ save_intermediate : bool
78
+ Save intermediate results
79
+ create_reports : bool
80
+ Create reports
81
+
82
+ Returns:
83
+ --------
84
+ pd.DataFrame
85
+ Processed data
86
+ """
87
+ logger.info("\n" + "="*80)
88
+ logger.info("RUNNING ENHANCED DATA PREPROCESSING PIPELINE")
89
+ logger.info("="*80)
90
+
91
+ start_time = datetime.now()
92
+
93
+ try:
94
+ # Step 1: Data loading
95
+ logger.info("\n" + "="*80)
96
+ logger.info("STEP 1: DATA LOADING")
97
+ logger.info("="*80)
98
+
99
+ if use_synthetic:
100
+ data = self.data_loader.create_synthetic_data(
101
+ n_days=365*20,
102
+ trend_strength=0.01,
103
+ noise_std=10,
104
+ include_exogenous=True
105
+ )
106
+ else:
107
+ data = self.data_loader.load_from_csv(
108
+ data_path=data_path,
109
+ parse_dates=['date']
110
+ )
111
+
112
+ # Check for date index
113
+ if not isinstance(data.index, pd.DatetimeIndex):
114
+ logger.warning("Index is not DatetimeIndex, setting...")
115
+ if 'date' in data.columns:
116
+ data = data.set_index('date')
117
+ logger.info("Index set from 'date' column")
118
+
119
+ self.results['data_loading'] = {
120
+ 'shape': list(data.shape),
121
+ 'columns': list(data.columns),
122
+ 'date_range': {
123
+ 'min': data.index.min().strftime('%Y-%m-%d') if isinstance(data.index, pd.DatetimeIndex) else None,
124
+ 'max': data.index.max().strftime('%Y-%m-%d') if isinstance(data.index, pd.DatetimeIndex) else None
125
+ },
126
+ 'is_datetime_index': isinstance(data.index, pd.DatetimeIndex)
127
+ }
128
+
129
+ # Save raw data information
130
+ self.data_loader.save_raw_data_info()
131
+
132
+ # Step 2: Raw data validation
133
+ logger.info("\n" + "="*80)
134
+ logger.info("STEP 2: RAW DATA VALIDATION")
135
+ logger.info("="*80)
136
+
137
+ raw_validation = self.data_validator.validate(
138
+ data, stage='raw', detailed=True
139
+ )
140
+ self.results['raw_validation'] = raw_validation
141
+
142
+ if raw_validation['status'] == 'FAIL':
143
+ logger.warning("⚠ Raw data has critical issues!")
144
+ if not self.config.enable_validation:
145
+ logger.warning("Validation disabled in configuration, continuing processing")
146
+ else:
147
+ logger.error("Pipeline interrupted due to data issues")
148
+ return None
149
+
150
+ # Step 3: Missing values analysis and handling
151
+ logger.info("\n" + "="*80)
152
+ logger.info("STEP 3: MISSING VALUES HANDLING")
153
+ logger.info("="*80)
154
+
155
+ missing_info = self.missing_analyser.analyse(data, detailed=True)
156
+ self.results['missing_analysis'] = missing_info
157
+
158
+ # Handle missing values
159
+ data = self.missing_analyser.handle(
160
+ data,
161
+ method='interpolate',
162
+ strategy='columnwise'
163
+ )
164
+ self.results['missing_handling'] = self.missing_analyser.handling_methods
165
+
166
+ # Step 4: Outlier analysis and handling
167
+ logger.info("\n" + "="*80)
168
+ logger.info("STEP 4: OUTLIER HANDLING")
169
+ logger.info("="*80)
170
+
171
+ outlier_info = self.outlier_analyser.analyse(
172
+ data,
173
+ method=self.config.outlier_method,
174
+ columns=data.select_dtypes(include=[np.number]).columns.tolist()
175
+ )
176
+ self.results['outlier_analysis'] = outlier_info
177
+
178
+ # Handle outliers
179
+ data = self.outlier_analyser.handle(
180
+ data,
181
+ method='clip',
182
+ strategy='columnwise'
183
+ )
184
+ self.results['outlier_handling'] = self.outlier_analyser.handling_methods
185
+
186
+ # Step 5: Feature engineering
187
+ logger.info("\n" + "="*80)
188
+ logger.info("STEP 5: FEATURE ENGINEERING")
189
+ logger.info("="*80)
190
+
191
+ data = self.feature_engineer.create_all_features(data)
192
+ self.results['feature_engineering'] = self.feature_engineer.feature_info
193
+
194
+ # Check for data after feature engineering
195
+ if len(data) == 0:
196
+ logger.error("No data remaining after feature engineering!")
197
+ return None
198
+
199
+ # Step 6: Stationarity analysis
200
+ logger.info("\n" + "="*80)
201
+ logger.info("STEP 6: STATIONARITY ANALYSIS")
202
+ logger.info("="*80)
203
+
204
+ stationarity_results = self.stationarity_checker.check(
205
+ data,
206
+ target_col=self.config.target_column,
207
+ make_stationary=True,
208
+ try_transformations=True
209
+ )
210
+ self.results['stationarity_analysis'] = stationarity_results
211
+
212
+ # Step 7: Time series decomposition
213
+ logger.info("\n" + "="*80)
214
+ logger.info("STEP 7: TIME SERIES DECOMPOSITION")
215
+ logger.info("="*80)
216
+
217
+ if isinstance(data.index, pd.DatetimeIndex) and len(data) > 365:
218
+ decomposition_results = self.decomposer.decompose(
219
+ data,
220
+ target_col=self.config.target_column,
221
+ method='stl',
222
+ period=self.config.seasonal_period
223
+ )
224
+ self.results['decomposition'] = decomposition_results
225
+ else:
226
+ logger.info("Skipped: insufficient data or no DatetimeIndex")
227
+ self.results['decomposition'] = {'skipped': 'insufficient data or no DatetimeIndex'}
228
+
229
+ # Step 8: Correlation analysis
230
+ logger.info("\n" + "="*80)
231
+ logger.info("STEP 8: CORRELATION ANALYSIS")
232
+ logger.info("="*80)
233
+
234
+ corr_matrix = self.correlation_analyser.analyze(
235
+ data,
236
+ target_col=self.config.target_column,
237
+ threshold=0.8,
238
+ detailed=True
239
+ )
240
+ self.results['correlation_analysis'] = self.correlation_analyser.get_report()
241
+
242
+ # Remove highly correlated features
243
+ if not corr_matrix.empty:
244
+ data = self.correlation_analyser.remove_highly_correlated(
245
+ data,
246
+ threshold=0.95,
247
+ method='variance',
248
+ keep_target=True
249
+ )
250
+ else:
251
+ logger.warning("Correlation matrix empty, skipping feature removal")
252
+
253
+ # Step 9: Processed data validation
254
+ logger.info("\n" + "="*80)
255
+ logger.info("STEP 9: PROCESSED DATA VALIDATION")
256
+ logger.info("="*80)
257
+
258
+ processed_validation = self.data_validator.validate(
259
+ data, stage='processed', detailed=True
260
+ )
261
+ self.results['processed_validation'] = processed_validation
262
+
263
+ if processed_validation['status'] == 'FAIL':
264
+ logger.warning("⚠ Processed data failed validation!")
265
+ logger.warning("Continuing pipeline, but data quality may be low")
266
+ elif processed_validation['status'] == 'WARNING':
267
+ logger.warning("⚠ Processed data requires attention")
268
+
269
+ # Step 10: Data splitting
270
+ logger.info("\n" + "="*80)
271
+ logger.info("STEP 10: DATA SPLITTING")
272
+ logger.info("="*80)
273
+
274
+ train_data, val_data, test_data = self.data_splitter.split(
275
+ data,
276
+ method=self.config.split_method,
277
+ test_size=self.config.test_size,
278
+ validation_size=self.config.validation_size
279
+ )
280
+
281
+ self.train_data = train_data
282
+ self.val_data = val_data
283
+ self.test_data = test_data
284
+
285
+ self.results['data_splitting'] = self.data_splitter.split_info
286
+
287
+ # Step 11: Data scaling
288
+ logger.info("\n" + "="*80)
289
+ logger.info("STEP 11: DATA SCALING")
290
+ logger.info("="*80)
291
+
292
+ # Scale training data
293
+ train_data_scaled = self.data_scaler.fit_transform(
294
+ train_data,
295
+ method=self.config.scaling_method,
296
+ target_col=self.config.target_column,
297
+ fit_on_train=True
298
+ )
299
+
300
+ # Apply same scaling to validation and test data
301
+ val_data_scaled = self.data_scaler.transform(val_data)
302
+ test_data_scaled = self.data_scaler.transform(test_data)
303
+
304
+ self.train_data = train_data_scaled
305
+ self.val_data = val_data_scaled
306
+ self.test_data = test_data_scaled
307
+
308
+ self.results['data_scaling'] = self.data_scaler.get_report()
309
+
310
+ # Step 12: Feature selection
311
+ logger.info("\n" + "="*80)
312
+ logger.info("STEP 12: FEATURE SELECTION")
313
+ logger.info("="*80)
314
+
315
+ if len(train_data_scaled.columns) > 5:
316
+ # Select features on training data
317
+ train_data_selected = self.feature_selector.select(
318
+ train_data_scaled,
319
+ method=self.config.feature_selection_method,
320
+ n_features=min(self.config.max_features, len(train_data_scaled.columns) - 1)
321
+ )
322
+
323
+ # Save selected features
324
+ selected_features = self.feature_selector.selected_features
325
+
326
+ # Apply same selection to validation and test data
327
+ features_to_keep = selected_features + [self.config.target_column]
328
+ features_to_keep = [f for f in features_to_keep if f in val_data_scaled.columns]
329
+
330
+ if len(features_to_keep) > 1:
331
+ self.train_data = train_data_scaled[features_to_keep].copy()
332
+ self.val_data = val_data_scaled[features_to_keep].copy()
333
+ self.test_data = test_data_scaled[features_to_keep].copy()
334
+ else:
335
+ logger.warning("Failed to select features, using all")
336
+
337
+ self.results['feature_selection'] = self.feature_selector.get_report()
338
+ else:
339
+ logger.info("Skipped: insufficient features for selection")
340
+ self.results['feature_selection'] = {'skipped': 'insufficient features'}
341
+
342
+ # Step 13: Final validation
343
+ logger.info("\n" + "="*80)
344
+ logger.info("STEP 13: FINAL VALIDATION")
345
+ logger.info("="*80)
346
+
347
+ # Combine all data for final validation
348
+ all_processed_data = pd.concat([self.train_data, self.val_data, self.test_data])
349
+
350
+ final_validation = self.data_validator.validate(
351
+ all_processed_data, stage='final', detailed=True
352
+ )
353
+ self.results['final_validation'] = final_validation
354
+
355
+ self.processed_data = all_processed_data
356
+ self.is_fitted = True
357
+
358
+ # Step 14: Additional multicollinearity cleaning
359
+ logger.info("\n" + "="*80)
360
+ logger.info("STEP 14: ADDITIONAL MULTICOLLINEARITY CLEANING")
361
+ logger.info("="*80)
362
+
363
+ # Remove temporal features with extreme VIF
364
+ self.processed_data = self._remove_extreme_vif_features(self.processed_data)
365
+ self.train_data = self.train_data[self.processed_data.columns]
366
+ self.val_data = self.val_data[self.processed_data.columns]
367
+ self.test_data = self.test_data[self.processed_data.columns]
368
+
369
+ # Step 15: Create visualisations and reports
370
+ logger.info("\n" + "="*80)
371
+ logger.info("STEP 15: CREATING REPORTS AND VISUALISATIONS")
372
+ logger.info("="*80)
373
+
374
+ if create_reports:
375
+ self.create_all_reports()
376
+ self.create_all_visualisations()
377
+
378
+ # Calculate execution time
379
+ execution_time = (datetime.now() - start_time).total_seconds()
380
+
381
+ # Save final results
382
+ self.results['pipeline_execution'] = {
383
+ 'start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'),
384
+ 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
385
+ 'execution_time_seconds': execution_time,
386
+ 'success': True,
387
+ 'stages_completed': 15
388
+ }
389
+
390
+ # Save configuration and results
391
+ self.save_pipeline_results()
392
+
393
+ logger.info("\n" + "="*80)
394
+ logger.info("ENHANCED PIPELINE SUCCESSFULLY COMPLETED!")
395
+ logger.info("="*80)
396
+ logger.info(f"Execution time: {execution_time:.2f} seconds")
397
+ logger.info(f"Initial data size: {self.results['data_loading']['shape']}")
398
+ logger.info(f"Final data size: {list(self.processed_data.shape)}")
399
+ logger.info(f"Data quality: {final_validation['overall_score']}/100")
400
+ logger.info(f"Status: {final_validation['status']}")
401
+ logger.info(f"Training data: {len(self.train_data)} records")
402
+ logger.info(f"Features in final set: {len(self.train_data.columns)}")
403
+
404
+ return self.processed_data
405
+
406
+ except Exception as e:
407
+ logger.error(f"✗ Pipeline error: {e}")
408
+ logger.error(traceback.format_exc())
409
+
410
+ self.results['pipeline_execution'] = {
411
+ 'success': False,
412
+ 'error': str(e),
413
+ 'traceback': traceback.format_exc()
414
+ }
415
+
416
+ # Save partial results
417
+ self.save_pipeline_results()
418
+
419
+ return None
420
+
421
+ def _remove_extreme_vif_features(self, data: pd.DataFrame) -> pd.DataFrame:
422
+ """Remove features with extreme VIF"""
423
+ data_clean = data.copy()
424
+
425
+ # Identify features with extreme VIF for removal
426
+ extreme_vif_features = [
427
+ 'year', 'day', 'dayofyear', 'days_from_start',
428
+ 'raskhodvoda_zscore' # Usually has extreme VIF
429
+ ]
430
+
431
+ # Remove only those present in data
432
+ features_to_remove = [f for f in extreme_vif_features if f in data_clean.columns]
433
+
434
+ if features_to_remove:
435
+ logger.info(f"Removing features with extreme VIF: {features_to_remove}")
436
+ data_clean = data_clean.drop(columns=features_to_remove)
437
+
438
+ return data_clean
439
+
440
+ def create_all_reports(self) -> None:
441
+ """Create all reports"""
442
+ logger.info("Creating reports...")
443
+
444
+ # 1. Save validation results
445
+ for stage in ['raw', 'processed', 'final']:
446
+ if stage in self.data_validator.validation_results:
447
+ self.data_validator.save_report(stage)
448
+
449
+ # 2. Save plots information
450
+ self.visualisation_manager.save_plots_info()
451
+
452
+ # 3. Create summary report
453
+ self.create_summary_report()
454
+
455
+ logger.info("✓ All reports created")
456
+
457
+ def create_all_visualisations(self) -> None:
458
+ """Create all visualisations"""
459
+ logger.info("Creating visualisations...")
460
+
461
+ if self.processed_data is not None:
462
+ # 1. Summary dashboard
463
+ preprocessing_stages = {
464
+ 'Loading': self.results['data_loading']['shape'][1] if 'data_loading' in self.results else 0,
465
+ 'After cleaning': len(self.processed_data.columns),
466
+ 'Features created': self.feature_engineer.feature_info.get('features_created', 0),
467
+ 'Features selected': len(self.feature_selector.selected_features) if hasattr(self.feature_selector, 'selected_features') else 0
468
+ }
469
+
470
+ self.visualisation_manager.create_summary_dashboard(
471
+ self.processed_data,
472
+ preprocessing_stages
473
+ )
474
+
475
+ logger.info("✓ All visualisations created")
476
+
477
+ def create_summary_report(self) -> None:
478
+ """Create summary report"""
479
+ report = {
480
+ 'pipeline_summary': {
481
+ 'config': self.config.to_dict(),
482
+ 'execution': self.results.get('pipeline_execution', {}),
483
+ 'data_statistics': {
484
+ 'initial_shape': self.results.get('data_loading', {}).get('shape', []),
485
+ 'final_shape': list(self.processed_data.shape) if self.processed_data is not None else [],
486
+ 'target_column': self.config.target_column,
487
+ 'features_created': self.feature_engineer.feature_info.get('features_created', 0),
488
+ 'features_selected': len(self.feature_selector.selected_features) if hasattr(self.feature_selector, 'selected_features') else 0
489
+ }
490
+ },
491
+ 'validation_summary': {},
492
+ 'quality_metrics': {}
493
+ }
494
+
495
+ # Add validation results
496
+ for stage in ['raw', 'processed', 'final']:
497
+ if stage in self.data_validator.validation_results:
498
+ stage_results = self.data_validator.validation_results[stage]
499
+ report['validation_summary'][stage] = {
500
+ 'status': stage_results.get('status'),
501
+ 'score': stage_results.get('overall_score'),
502
+ 'issues_count': sum(len(issues) for issues in stage_results.get('issues', {}).values()),
503
+ 'checks_passed': sum(1 for check in stage_results.get('basic_checks', {}).values()
504
+ if check.get('passed', False))
505
+ }
506
+
507
+ # Save report
508
+ report_path = f'{self.config.results_dir}/reports/pipeline_summary.json'
509
+
510
+ with open(report_path, 'w', encoding='utf-8') as f:
511
+ json.dump(report, f, indent=4, ensure_ascii=False)
512
+
513
+ logger.info(f"✓ Summary report saved: {report_path}")
514
+
515
+ def save_pipeline_results(self) -> None:
516
+ """Save all pipeline results"""
517
+ # Save configuration
518
+ self.config.save()
519
+
520
+ # Save data
521
+ if self.processed_data is not None:
522
+ # Save processed data
523
+ data_path = f'{self.config.results_dir}/processed_data/processed_data.csv'
524
+ self.processed_data.to_csv(data_path)
525
+ logger.info(f"✓ Processed data saved: {data_path}")
526
+
527
+ # Save split data
528
+ if self.train_data is not None:
529
+ self.train_data.to_csv(f'{self.config.results_dir}/processed_data/train_data.csv')
530
+ self.val_data.to_csv(f'{self.config.results_dir}/processed_data/val_data.csv')
531
+ self.test_data.to_csv(f'{self.config.results_dir}/processed_data/test_data.csv')
532
+
533
+ def get_final_data_for_modelling(self) -> Dict[str, Any]:
534
+ """Prepare data for modelling"""
535
+ if not self.is_fitted:
536
+ logger.warning("Pipeline not executed, data not ready")
537
+ return {}
538
+
539
+ return {
540
+ 'X_train': self.train_data.drop(columns=[self.config.target_column]),
541
+ 'y_train': self.train_data[self.config.target_column],
542
+ 'X_val': self.val_data.drop(columns=[self.config.target_column]),
543
+ 'y_val': self.val_data[self.config.target_column],
544
+ 'X_test': self.test_data.drop(columns=[self.config.target_column]),
545
+ 'y_test': self.test_data[self.config.target_column],
546
+ 'feature_names': self.train_data.drop(columns=[self.config.target_column]).columns.tolist(),
547
+ 'scaler': self.data_scaler,
548
+ 'feature_selector': self.feature_selector,
549
+ 'results': self.results
550
+ }
551
+
552
+
553
+ # ============================================
554
+ # QUICK LAUNCH FUNCTION
555
+ # ============================================
556
+ def run_enhanced_preprocessing(
557
+ config_path: Optional[str] = None,
558
+ data_path: Optional[str] = None,
559
+ use_synthetic: bool = False,
560
+ save_results: bool = True
561
+ ) -> EnhancedDataPreprocessingPipeline:
562
+ """
563
+ Quick launch function for enhanced pipeline
564
+
565
+ Parameters:
566
+ -----------
567
+ config_path : str, optional
568
+ Path to configuration file
569
+ data_path : str, optional
570
+ Path to data
571
+ use_synthetic : bool
572
+ Use synthetic data
573
+ save_results : bool
574
+ Save results
575
+
576
+ Returns:
577
+ --------
578
+ EnhancedDataPreprocessingPipeline
579
+ Pipeline object with results
580
+ """
581
+ # Load or create configuration
582
+ if config_path and os.path.exists(config_path):
583
+ config = Config.load(config_path)
584
+ logger.info(f"Configuration loaded from {config_path}")
585
+ else:
586
+ config = Config()
587
+ logger.info("Using default configuration")
588
+
589
+ # Update data path if specified
590
+ if data_path:
591
+ config.data_path = data_path
592
+
593
+ # Create and run pipeline
594
+ pipeline = EnhancedDataPreprocessingPipeline(config)
595
+
596
+ pipeline.run_full_pipeline(
597
+ data_path=data_path,
598
+ use_synthetic=use_synthetic,
599
+ save_intermediate=save_results,
600
+ create_reports=save_results
601
+ )
602
+
603
+ return pipeline
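
For context, a hedged sketch of how the dictionary returned by `get_final_data_for_modelling()` might feed a downstream model (the estimator choice here is illustrative, not part of the pipeline):

```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error

# Quick launch via the helper defined above
pipeline = run_enhanced_preprocessing(data_path='temp_data.csv')
bundle = pipeline.get_final_data_for_modelling()

if bundle:  # an empty dict means the pipeline did not complete
    model = RandomForestRegressor(n_estimators=200, random_state=42)
    model.fit(bundle['X_train'], bundle['y_train'])
    preds = model.predict(bundle['X_val'])
    print('Validation MAE:', mean_absolute_error(bundle['y_val'], preds))
```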
requirements.txt CHANGED
@@ -1,3 +1,100 @@
1
- altair
2
- pandas
3
- streamlit
1
+ absl-py==2.3.1
2
+ altair==6.0.0
3
+ anyio==4.11.0
4
+ astunparse==1.6.3
5
+ attrs==25.4.0
6
+ blinker==1.9.0
7
+ cachetools==6.2.4
8
+ certifi==2025.11.12
9
+ charset-normalizer==3.4.4
10
+ click==8.3.1
11
+ colorama==0.4.6
12
+ contourpy==1.3.2
13
+ cycler==0.12.1
14
+ et_xmlfile==2.0.0
15
+ filelock==3.20.1
16
+ flatbuffers==25.12.19
17
+ fonttools==4.61.1
18
+ fsspec==2025.12.0
19
+ gast==0.7.0
20
+ gensim==4.4.0
21
+ gitdb==4.0.12
22
+ GitPython==3.1.46
23
+ google-pasta==0.2.0
24
+ grpcio==1.76.0
25
+ h5py==3.15.1
26
+ h11==0.16.0
27
+ hf-xet==1.2.0
28
+ httpcore==1.0.9
29
+ httpx==0.28.1
30
+ huggingface_hub==1.1.2
31
+ idna==3.11
32
+ Jinja2==3.1.6
33
+ joblib==1.5.3
34
+ jsonschema==4.25.1
35
+ jsonschema-specifications==2025.9.1
36
+ keras==3.13.0
37
+ kiwisolver==1.4.9
38
+ libclang==18.1.1
39
+ Markdown==3.10
40
+ markdown-it-py==4.0.0
41
+ MarkupSafe==3.0.3
42
+ matplotlib==3.10.8
43
+ mdurl==0.1.2
44
+ ml_dtypes==0.5.4
45
+ mpmath==1.3.0
46
+ namex==0.1.0
47
+ narwhals==2.14.0
48
+ networkx==3.6.1
49
+ numpy==2.4.0
50
+ openpyxl==3.1.5
51
+ opt_einsum==3.4.0
52
+ optree==0.18.0
53
+ packaging==25.0
54
+ pandas==2.3.3
55
+ patsy==1.0.2
56
+ pillow==12.0.0
57
+ plotly==6.5.0
58
+ protobuf==6.33.2
59
+ pyarrow==22.0.0
60
+ pydeck==0.9.1
61
+ Pygments==2.19.2
62
+ pyparsing==3.3.1
63
+ pyperclip==1.11.0
64
+ python-dateutil==2.9.0.post0
65
+ pytz==2025.2
66
+ PyYAML==6.0.3
67
+ referencing==0.37.0
68
+ requests==2.32.5
69
+ rich==14.2.0
70
+ rpds-py==0.30.0
71
+ scikit-learn==1.8.0
72
+ scipy==1.16.3
73
+ seaborn==0.13.2
74
+ setuptools==80.9.0
75
+ shellingham==1.5.4
76
+ six==1.17.0
77
+ smart_open==7.5.0
78
+ smmap==5.0.2
79
+ sniffio==1.3.1
80
+ statsmodels==0.14.6
81
+ streamlit==1.52.2
82
+ sympy==1.14.0
83
+ tenacity==9.1.2
84
+ tensorboard==2.20.0
85
+ tensorboard-data-server==0.7.2
86
+ tensorflow==2.20.0
87
+ termcolor==3.3.0
88
+ threadpoolctl==3.6.0
89
+ toml==0.10.2
90
+ torch==2.9.1
91
+ tornado==6.5.4
92
+ tqdm==4.67.1
93
+ typer-slim==0.20.0
94
+ typing_extensions==4.15.0
95
+ tzdata==2025.3
96
+ urllib3==2.6.2
97
+ watchdog==6.0.0
98
+ Werkzeug==3.1.4
99
+ wheel==0.45.1
100
+ wrapt==2.0.1
run_pipeline.py ADDED
@@ -0,0 +1,62 @@
1
+ # ============================================
2
+ # RUN
3
+ # ============================================
4
+ from config.config import Config
5
+ from pipeline.main_pipeline import EnhancedDataPreprocessingPipeline
6
+ import pandas as pd
7
+
8
+ if __name__ == "__main__":
9
+ """
10
+ Pipeline execution
11
+ """
12
+
13
+ # Configuration with reasonable parameters
14
+ config = Config(
15
+ data_path='temp_data.csv',
16
+ results_dir='enhanced_preprocessing_results',
17
+ target_column='raskhodvoda',
18
+ start_year=1970,
19
+ end_year=1990,
20
+ max_lags=5,
21
+ seasonal_period=365,
22
+ rolling_windows=[7, 30, 90],
23
+ expanding_windows=[30, 90],
24
+ test_size=0.2,
25
+ validation_size=0.1,
26
+ scaling_method='robust',
27
+ feature_selection_method='correlation',
28
+ max_features=20,
29
+ missing_threshold=0.3,
30
+ outlier_method='iqr',
31
+ enable_validation=True
32
+ )
33
+
34
+ # Run enhanced pipeline
35
+ pipeline = EnhancedDataPreprocessingPipeline(config)
36
+ processed_data = pipeline.run_full_pipeline(
37
+ use_synthetic=False,
38
+ save_intermediate=True,
39
+ create_reports=True
40
+ )
41
+
42
+ if processed_data is not None:
43
+ print("\n" + "="*80)
44
+ print("ENHANCED PIPELINE SUCCESSFULLY COMPLETED!")
45
+ print("="*80)
46
+ print(f"Final data size: {processed_data.shape}")
47
+ print(f"Columns: {list(processed_data.columns)}")
48
+
49
+ # Get modeling data
50
+ modeling_data = pipeline.get_final_data_for_modelling()
51
+
52
+ if modeling_data:
53
+ print(f"\nModeling data ready:")
54
+ print(f" X_train: {modeling_data['X_train'].shape}")
55
+ print(f" X_val: {modeling_data['X_val'].shape}")
56
+ print(f" X_test: {modeling_data['X_test'].shape}")
57
+ print(f" Features: {len(modeling_data['feature_names'])}")
58
+
59
+ # Save final data
60
+ processed_data.to_csv('enhanced_preprocessing_results/processed_data/enhanced_final_processed_data.csv',
61
+ index=isinstance(processed_data.index, pd.DatetimeIndex))
62
+ print(f"\n✓ Final data saved to 'enhanced_final_processed_data.csv'")
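
The saved splits can be reloaded later without re-running the pipeline; a minimal sketch, assuming the `results_dir` and `target_column` from the configuration above:

```python
import pandas as pd

results_dir = 'enhanced_preprocessing_results/processed_data'
train = pd.read_csv(f'{results_dir}/train_data.csv', index_col=0, parse_dates=True)
test = pd.read_csv(f'{results_dir}/test_data.csv', index_col=0, parse_dates=True)

X_train = train.drop(columns=['raskhodvoda'])
y_train = train['raskhodvoda']
print(X_train.shape, y_train.shape, test.shape)
```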
scaling/__init__.py ADDED
File without changes
scaling/data_scaler.py ADDED
@@ -0,0 +1,634 @@
1
+ # ============================================
2
+ # CLASS 10: DATA SCALING
3
+ # ============================================
4
+ from typing import Dict, List, Optional, Tuple
5
+ import logging
6
+ import pandas as pd
7
+ from config.config import Config
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+
+ logger = logging.getLogger(__name__)
11
+
12
+ class DataScaler:
13
+ """Class for data scaling and normalisation"""
14
+
15
+ def __init__(self, config: Config):
16
+ """
17
+ Initialise scaler
18
+
19
+ Parameters:
20
+ -----------
21
+ config : Config
22
+ Experiment configuration
23
+ """
24
+ self.config = config
25
+ self.scalers = {}
26
+ self.scaling_info = {}
27
+ self.transforms_applied = {}
28
+
29
+ def fit_transform(
30
+ self,
31
+ data: pd.DataFrame,
32
+ method: str = None,
33
+ columns: List[str] = None,
34
+ target_col: Optional[str] = None,
35
+ fit_on_train: bool = True,
36
+ **kwargs
37
+ ) -> pd.DataFrame:
38
+ """
39
+ Scale data
40
+
41
+ Parameters:
42
+ -----------
43
+ data : pd.DataFrame
44
+ Input data
45
+ method : str, optional
46
+ Scaling method. If None, uses configuration value.
47
+ columns : List[str], optional
48
+ List of columns to scale. If None, uses all numeric columns.
49
+ target_col : str, optional
50
+ Target variable (not scaled by default)
51
+ fit_on_train : bool
52
+ Whether to save scaling parameters for applying to new data
53
+ **kwargs : dict
54
+ Additional parameters for method
55
+
56
+ Returns:
57
+ --------
58
+ pd.DataFrame
59
+ Scaled data
60
+ """
61
+ logger.info("\n" + "="*80)
62
+ logger.info("DATA SCALING")
63
+ logger.info("="*80)
64
+
65
+ method = method or self.config.scaling_method
66
+ data_scaled = data.copy()
67
+
68
+ if columns is None:
69
+ # Select all numeric columns except target
70
+ numeric_cols = data_scaled.select_dtypes(include=[np.number]).columns
71
+ if target_col and target_col in numeric_cols:
72
+ columns = [col for col in numeric_cols if col != target_col]
73
+ else:
74
+ columns = list(numeric_cols)
75
+
76
+ logger.info(f"Scaling method: {method}")
77
+ logger.info(f"Columns to process: {len(columns)}")
78
+
79
+ # Apply scaling
80
+ for col in columns:
81
+ if col in data_scaled.columns:
82
+ try:
83
+ # Check feature type
84
+ series = data_scaled[col].dropna()
85
+
86
+ # Special handling for different feature types
87
+ if self._is_binary_feature(series):
88
+ logger.debug(f" {col}: binary feature, scaling not applied")
89
+ scaler_info = {
90
+ 'method': 'none',
91
+ 'scaler_type': 'binary',
92
+ 'original_values': sorted(series.unique().tolist()),
93
+ 'note': 'binary feature, no scaling applied'
94
+ }
95
+ self.scaling_info[col] = scaler_info
96
+
97
+ if fit_on_train:
98
+ self.scalers[col] = scaler_info
99
+
100
+ elif self._is_categorical_feature(series):
101
+ logger.debug(f" {col}: categorical feature, using min-max")
102
+ scaled_series, scaler_info = self._apply_scaling(
103
+ data_scaled[col], 'minmax', fit_on_train, **kwargs
104
+ )
105
+ data_scaled[col] = scaled_series
106
+ self.scaling_info[col] = scaler_info
107
+
108
+ if fit_on_train:
109
+ if scaler_info.get('scaler_type') == 'sklearn':
110
+ self.scalers[col] = scaler_info['scaler_object']
111
+ else:
112
+ self.scalers[col] = scaler_info
113
+
114
+ else:
115
+ # Regular scaling for continuous features
116
+ scaled_series, scaler_info = self._apply_scaling(
117
+ data_scaled[col], method, fit_on_train, **kwargs
118
+ )
119
+ data_scaled[col] = scaled_series
120
+ self.scaling_info[col] = scaler_info
121
+
122
+ if fit_on_train:
123
+ if scaler_info.get('scaler_type') == 'sklearn':
124
+ self.scalers[col] = scaler_info['scaler_object']
125
+ else:
126
+ self.scalers[col] = scaler_info
127
+
128
+ except Exception as e:
129
+ logger.warning(f"Error processing column {col}: {e}")
130
+ # Save error information
131
+ self.scaling_info[col] = {
132
+ 'method': 'error',
133
+ 'error': str(e),
134
+ 'scaler_type': 'none'
135
+ }
136
+
137
+ logger.info(f"✓ Data processed using {method} method")
138
+
139
+ # Visualisation of results
140
+ if self.config.save_plots and columns:
141
+ self._plot_scaling_results(data, data_scaled, columns, method)
142
+
143
+ return data_scaled
144
+
145
+ def _is_binary_feature(self, series: pd.Series) -> bool:
146
+ """Check if feature is binary"""
147
+ unique_values = series.dropna().unique()
148
+ return len(unique_values) == 2 and set(unique_values).issubset({0, 1})
149
+
150
+ def _is_categorical_feature(self, series: pd.Series, max_categories: int = 10) -> bool:
151
+ """Check if feature is categorical"""
152
+ unique_values = series.dropna().unique()
153
+ return len(unique_values) <= max_categories and series.dtype in ['int64', 'float64']
154
+
155
+ def _apply_scaling(
156
+ self,
157
+ series: pd.Series,
158
+ method: str,
159
+ fit_on_train: bool,
160
+ **kwargs
161
+ ) -> Tuple[pd.Series, Dict]:
162
+ """Apply specific scaling method"""
163
+ series_clean = series.dropna()
164
+
165
+ if len(series_clean) == 0:
166
+ return series, {
167
+ 'method': 'none',
168
+ 'scaler_type': 'none',
169
+ 'error': 'all values are NaN'
170
+ }
171
+
172
+ scaler_info = {
173
+ 'method': method,
174
+ 'scaler_type': 'simple',
175
+ 'original_mean': float(series_clean.mean()),
176
+ 'original_std': float(series_clean.std()),
177
+ 'original_min': float(series_clean.min()),
178
+ 'original_max': float(series_clean.max()),
179
+ 'scaler': None,
180
+ 'scaler_object': None
181
+ }
182
+
183
+ try:
184
+ if method == 'standard':
185
+ # Standardisation (z-score normalisation)
186
+ mean = series_clean.mean()
187
+ std = series_clean.std()
188
+
189
+ if std > 0:
190
+ series_scaled = (series - mean) / std
191
+ scaler_info['scaler'] = {'mean': float(mean), 'std': float(std)}
192
+ else:
193
+ series_scaled = series - mean # If std = 0, just center
194
+ scaler_info['scaler'] = {'mean': float(mean), 'std': 0}
195
+
196
+ elif method == 'minmax':
197
+ # Min-Max normalisation
198
+ min_val = series_clean.min()
199
+ max_val = series_clean.max()
200
+
201
+ if max_val > min_val:
202
+ series_scaled = (series - min_val) / (max_val - min_val)
203
+ scaler_info['scaler'] = {'min': float(min_val), 'max': float(max_val)}
204
+ else:
205
+ series_scaled = series - min_val # If all values equal
206
+ scaler_info['scaler'] = {'min': float(min_val), 'max': float(min_val)}
207
+
208
+ elif method == 'robust':
209
+ # Robust scaling (outlier resistant)
210
+ # Check sufficient values for quartile calculation
211
+ if len(series_clean) >= 4:
212
+ median = series_clean.median()
213
+ q1 = series_clean.quantile(0.25)
214
+ q3 = series_clean.quantile(0.75)
215
+ iqr = q3 - q1
216
+
217
+ if iqr > 0:
218
+ series_scaled = (series - median) / iqr
219
+ scaler_info['scaler'] = {
220
+ 'median': float(median),
221
+ 'q1': float(q1),
222
+ 'q3': float(q3),
223
+ 'iqr': float(iqr)
224
+ }
225
+ else:
226
+ # If IQR = 0, use standard deviation
227
+ std = series_clean.std()
228
+ if std > 0:
229
+ series_scaled = (series - median) / std
230
+ scaler_info['scaler'] = {'median': float(median), 'std': float(std)}
231
+ else:
232
+ series_scaled = series - median
233
+ scaler_info['scaler'] = {'median': float(median), 'iqr': 0}
234
+ else:
235
+ # If insufficient data, use standardisation
236
+ mean = series_clean.mean()
237
+ std = series_clean.std()
238
+ if std > 0:
239
+ series_scaled = (series - mean) / std
240
+ scaler_info['scaler'] = {'mean': float(mean), 'std': float(std)}
241
+ scaler_info['method'] = 'standard' # Change method in info
242
+ else:
243
+ series_scaled = series - mean
244
+ scaler_info['scaler'] = {'mean': float(mean), 'std': 0}
245
+ scaler_info['method'] = 'standard'
246
+
247
+ elif method == 'log':
248
+ # Logarithmic transformation
249
+ min_val = series_clean.min()
250
+
251
+ if min_val <= 0:
252
+ shift = abs(min_val) + 1
253
+ series_scaled = np.log(series + shift)
254
+ scaler_info['scaler'] = {'shift': float(shift)}
255
+ else:
256
+ series_scaled = np.log(series)
257
+ scaler_info['scaler'] = {'shift': 0}
258
+
259
+ elif method == 'boxcox':
260
+ # Box-Cox transformation
261
+ try:
262
+ from scipy.stats import boxcox
263
+
264
+ min_val = series_clean.min()
265
+ if min_val <= 0:
266
+ shift = abs(min_val) + 1
267
+ series_to_transform = series + shift
268
+ else:
269
+ shift = 0
270
+ series_to_transform = series
271
+
272
+ transformed, lambda_val = boxcox(series_to_transform.dropna())
273
+
274
+ # Map transformed values back onto the original index
275
+ series_scaled = series.copy()
276
+ valid_mask = series_to_transform.notna()
277
+ series_scaled[valid_mask] = transformed
278
+
279
+ scaler_info['scaler'] = {
280
+ 'lambda': float(lambda_val),
281
+ 'shift': float(shift)
282
+ }
283
+
284
+ except Exception as e:
285
+ logger.warning(f"Box-Cox transformation failed for {series.name}: {e}")
286
+ # Return original series and change method
287
+ series_scaled = series
288
+ scaler_info['method'] = 'none'
289
+ scaler_info['scaler_type'] = 'none'
290
+ scaler_info['error'] = str(e)
291
+
292
+ elif method == 'quantile':
293
+ # Quantile transformation (rank-based)
294
+ try:
295
+ from sklearn.preprocessing import QuantileTransformer
296
+
297
+ qt = QuantileTransformer(
298
+ n_quantiles=kwargs.get('n_quantiles', min(100, len(series_clean))),
299
+ output_distribution=kwargs.get('output_distribution', 'normal'),
300
+ random_state=kwargs.get('random_state', 42)
301
+ )
302
+
303
+ series_reshaped = series.values.reshape(-1, 1)
304
+ series_scaled_values = qt.fit_transform(series_reshaped)
305
+ series_scaled = pd.Series(series_scaled_values.flatten(), index=series.index)
306
+
307
+ scaler_info['scaler_type'] = 'sklearn'
308
+ scaler_info['scaler_object'] = qt
309
+
310
+ except Exception as e:
311
+ logger.warning(f"Quantile transform failed for {series.name}: {e}")
312
+ series_scaled = series
313
+ scaler_info['method'] = 'none'
314
+ scaler_info['scaler_type'] = 'none'
315
+ scaler_info['error'] = str(e)
316
+
317
+ elif method == 'power':
318
+ # Power transform (Yeo-Johnson)
319
+ try:
320
+ from sklearn.preprocessing import PowerTransformer
321
+
322
+ pt = PowerTransformer(method='yeo-johnson', standardize=True)
323
+
324
+ series_reshaped = series.values.reshape(-1, 1)
325
+ series_scaled_values = pt.fit_transform(series_reshaped)
326
+ series_scaled = pd.Series(series_scaled_values.flatten(), index=series.index)
327
+
328
+ scaler_info['scaler_type'] = 'sklearn'
329
+ scaler_info['scaler_object'] = pt
330
+
331
+ except Exception as e:
332
+ logger.warning(f"Power transform failed for {series.name}: {e}")
333
+ series_scaled = series
334
+ scaler_info['method'] = 'none'
335
+ scaler_info['scaler_type'] = 'none'
336
+ scaler_info['error'] = str(e)
337
+
338
+ elif method == 'none':
339
+ # No scaling
340
+ series_scaled = series
341
+ scaler_info['method'] = 'none'
342
+ scaler_info['scaler_type'] = 'none'
343
+
344
+ else:
345
+ logger.warning(f"Unknown scaling method: {method}, using standardisation")
346
+ return self._apply_scaling(series, 'standard', fit_on_train, **kwargs)
347
+
348
+ # Add statistics after scaling
349
+ scaled_clean = series_scaled.dropna()
350
+ if len(scaled_clean) > 0:
351
+ scaler_info.update({
352
+ 'scaled_mean': float(scaled_clean.mean()),
353
+ 'scaled_std': float(scaled_clean.std()),
354
+ 'scaled_min': float(scaled_clean.min()),
355
+ 'scaled_max': float(scaled_clean.max()),
356
+ 'skewness_before': float(series_clean.skew()),
357
+ 'skewness_after': float(scaled_clean.skew()),
358
+ 'kurtosis_before': float(series_clean.kurtosis()),
359
+ 'kurtosis_after': float(scaled_clean.kurtosis())
360
+ })
361
+
362
+ return series_scaled, scaler_info
363
+
364
+ except Exception as e:
365
+ logger.warning(f"Error applying method {method} for {series.name}: {e}")
366
+ return series, {
367
+ 'method': 'error',
368
+ 'scaler_type': 'none',
369
+ 'error': str(e)
370
+ }
371
+
372
+ def transform(
373
+ self,
374
+ data: pd.DataFrame,
375
+ columns: List[str] = None
376
+ ) -> pd.DataFrame:
377
+ """
378
+ Apply saved scaling to new data
379
+
380
+ Parameters:
381
+ -----------
382
+ data : pd.DataFrame
383
+ New data
384
+ columns : List[str], optional
385
+ List of columns to transform
386
+
387
+ Returns:
388
+ --------
389
+ pd.DataFrame
390
+ Transformed data
391
+ """
392
+ if not self.scalers:
393
+ logger.warning("Scalers not trained, use fit_transform first")
394
+ return data
395
+
396
+ data_transformed = data.copy()
397
+
398
+ if columns is None:
399
+ columns = [col for col in self.scalers.keys() if col in data_transformed.columns]
400
+
401
+ for col in columns:
402
+ if col in data_transformed.columns and col in self.scalers:
403
+ try:
404
+ scaler_info = self.scaling_info.get(col, {})
405
+ scaler_data = self.scalers[col]
406
+ method = scaler_info.get('method', 'unknown')
407
+
408
+ # For binary features, do nothing
409
+ if method == 'none' and scaler_info.get('scaler_type') == 'binary':
410
+ continue
411
+
412
+ # Skip errors
413
+ if method == 'error':
414
+ continue
415
+
416
+ if isinstance(scaler_data, dict) and 'scaler' in scaler_data:
417
+ scaler_params = scaler_data['scaler']
418
+
419
+ if method == 'standard':
420
+ mean = scaler_params.get('mean', 0)
421
+ std = scaler_params.get('std', 1)
422
+ if std > 0:
423
+ data_transformed[col] = (data_transformed[col] - mean) / std
424
+
425
+ elif method == 'minmax':
426
+ min_val = scaler_params.get('min', 0)
427
+ max_val = scaler_params.get('max', 1)
428
+ if max_val > min_val:
429
+ data_transformed[col] = (data_transformed[col] - min_val) / (max_val - min_val)
430
+
431
+ elif method == 'robust':
432
+ median = scaler_params.get('median', 0)
433
+ iqr = scaler_params.get('iqr', 1)
434
+ if iqr > 0:
435
+ data_transformed[col] = (data_transformed[col] - median) / iqr
436
+ else:
437
+ std = scaler_params.get('std', 1)
438
+ if std > 0:
439
+ data_transformed[col] = (data_transformed[col] - median) / std
440
+
441
+ elif hasattr(scaler_data, 'transform'):
442
+ # For sklearn objects
443
+ from sklearn.base import BaseEstimator
444
+ if isinstance(scaler_data, BaseEstimator):
445
+ try:
446
+ transformed = scaler_data.transform(
447
+ data_transformed[[col]].values.reshape(-1, 1)
448
+ ).flatten()
449
+ data_transformed[col] = transformed
450
+ except Exception as e:
451
+ logger.warning(f"Error in sklearn transformation for {col}: {e}")
452
+
453
+ except Exception as e:
454
+ logger.warning(f"Error transforming column {col}: {e}")
455
+
456
+ return data_transformed
457
+
458
+ def inverse_transform(
459
+ self,
460
+ data: pd.DataFrame,
461
+ columns: List[str] = None
462
+ ) -> pd.DataFrame:
463
+ """
464
+ Inverse transform scaled data
465
+
466
+ Parameters:
467
+ -----------
468
+ data : pd.DataFrame
469
+ Scaled data
470
+ columns : List[str], optional
471
+ List of columns for inverse transform
472
+
473
+ Returns:
474
+ --------
475
+ pd.DataFrame
476
+ Data in original scale
477
+ """
478
+ if not self.scalers:
479
+ logger.warning("Scalers not trained")
480
+ return data
481
+
482
+ data_inverse = data.copy()
483
+
484
+ if columns is None:
485
+ columns = [col for col in self.scalers.keys() if col in data_inverse.columns]
486
+
487
+ for col in columns:
488
+ if col in data_inverse.columns and col in self.scalers:
489
+ try:
490
+ scaler_info = self.scaling_info.get(col, {})
491
+ scaler_data = self.scalers[col]
492
+ method = scaler_info.get('method', 'unknown')
493
+
494
+ # For binary and categorical features, do nothing
495
+ if method in ['none', 'error']:
496
+ continue
497
+
498
+ if isinstance(scaler_data, dict) and 'scaler' in scaler_data:
499
+ scaler_params = scaler_data['scaler']
500
+
501
+ if method == 'standard':
502
+ mean = scaler_params.get('mean', 0)
503
+ std = scaler_params.get('std', 1)
504
+ if std > 0:
+ data_inverse[col] = data_inverse[col] * std + mean
+ else:
+ data_inverse[col] = data_inverse[col] + mean  # std was 0 at fit time, so only the centring is undone
505
+
506
+ elif method == 'minmax':
507
+ min_val = scaler_params.get('min', 0)
508
+ max_val = scaler_params.get('max', 1)
509
+ if max_val > min_val:
510
+ data_inverse[col] = data_inverse[col] * (max_val - min_val) + min_val
511
+
512
+ elif method == 'robust':
513
+ median = scaler_params.get('median', 0)
514
+ iqr = scaler_params.get('iqr', 1)
515
+ if iqr > 0:
516
+ data_inverse[col] = data_inverse[col] * iqr + median
517
+ else:
518
+ std = scaler_params.get('std', 1)
519
+ if std > 0:
520
+ data_inverse[col] = data_inverse[col] * std + median
521
+
522
+ elif hasattr(scaler_data, 'inverse_transform'):
523
+ # For sklearn objects
524
+ from sklearn.base import BaseEstimator
525
+ if isinstance(scaler_data, BaseEstimator):
526
+ try:
527
+ inverse_transformed = scaler_data.inverse_transform(
528
+ data_inverse[[col]].values.reshape(-1, 1)
529
+ ).flatten()
530
+ data_inverse[col] = inverse_transformed
531
+ except Exception as e:
532
+ logger.warning(f"Error in sklearn inverse transformation for {col}: {e}")
533
+
534
+ except Exception as e:
535
+ logger.warning(f"Error in inverse transformation for column {col}: {e}")
536
+
537
+ return data_inverse
538
+
539
+ def _plot_scaling_results(
540
+ self,
541
+ original_data: pd.DataFrame,
542
+ scaled_data: pd.DataFrame,
543
+ columns: List[str],
544
+ method: str
545
+ ) -> None:
546
+ """Visualise scaling results"""
547
+ # Limit number of columns for visualisation
548
+ cols_to_plot = [col for col in columns if col in original_data.columns and col in scaled_data.columns][:8]
549
+
550
+ if not cols_to_plot:
551
+ return
552
+
553
+ n_cols = 4
554
+ n_rows = (len(cols_to_plot) + n_cols - 1) // n_cols
555
+
556
+ fig, axes = plt.subplots(n_rows, n_cols * 2, figsize=(16, 4 * n_rows), squeeze=False)  # squeeze=False keeps axes 2-D when n_rows == 1
557
+
558
+ for idx, col in enumerate(cols_to_plot):
559
+ row = idx // n_cols
560
+ col_idx = (idx % n_cols) * 2
561
+
562
+ # Distribution before scaling
563
+ axes[row, col_idx].hist(
564
+ original_data[col].dropna(),
565
+ bins=30,
566
+ alpha=0.7,
567
+ color='blue',
568
+ density=True
569
+ )
570
+ axes[row, col_idx].set_title(f'{col} (before)', fontsize=10)
571
+ axes[row, col_idx].set_xlabel('Value')
572
+ axes[row, col_idx].set_ylabel('Density')
573
+ axes[row, col_idx].grid(True, alpha=0.3)
574
+
575
+ # Distribution after scaling
576
+ axes[row, col_idx + 1].hist(
577
+ scaled_data[col].dropna(),
578
+ bins=30,
579
+ alpha=0.7,
580
+ color='green',
581
+ density=True
582
+ )
583
+ axes[row, col_idx + 1].set_title(f'{col} (after)', fontsize=10)
584
+ axes[row, col_idx + 1].set_xlabel('Scaled value')
585
+ axes[row, col_idx + 1].set_ylabel('Density')
586
+ axes[row, col_idx + 1].grid(True, alpha=0.3)
587
+
588
+ # Hide unused subplots
589
+ total_plots = n_rows * n_cols * 2
590
+ for idx in range(len(cols_to_plot) * 2, total_plots):
591
+ row = idx // (n_cols * 2)
592
+ col_idx = idx % (n_cols * 2)
593
+ axes[row, col_idx].set_visible(False)
594
+
595
+ plt.suptitle(f'Scaling results using {method} method', fontsize=14)
596
+ plt.tight_layout()
597
+ plt.savefig(
598
+ f'{self.config.results_dir}/plots/scaling_results.png',
599
+ dpi=300,
600
+ bbox_inches='tight'
601
+ )
602
+ plt.show()
603
+
604
+ def get_report(self) -> Dict:
605
+ """Get scaling report"""
606
+ summary = {
607
+ 'total_columns': len(self.scaling_info),
608
+ 'methods_used': {},
609
+ 'binary_features': [],
610
+ 'categorical_features': [],
611
+ 'continuous_features': [],
612
+ 'errors': []
613
+ }
614
+
615
+ for col, info in self.scaling_info.items():
616
+ method = info.get('method', 'unknown')
617
+ if method not in summary['methods_used']:
618
+ summary['methods_used'][method] = 0
619
+ summary['methods_used'][method] += 1
620
+
621
+ if method == 'none' and info.get('scaler_type') == 'binary':
622
+ summary['binary_features'].append(col)
623
+ elif method in ['minmax', 'standard', 'robust']:
624
+ summary['continuous_features'].append(col)
625
+ elif method == 'error':
626
+ summary['errors'].append({
627
+ 'column': col,
628
+ 'error': info.get('error', 'unknown')
629
+ })
630
+
631
+ return {
632
+ 'summary': summary,
633
+ 'details': self.scaling_info
634
+ }
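
Because the simple methods ('standard', 'minmax', 'robust') persist plain parameter dicts rather than fitted sklearn objects, the `transform`/`inverse_transform` round trip reduces to arithmetic. A minimal standalone sketch with hypothetical numbers; `params` stands in for a fitted `self.scalers[col]['scaler']` entry:

```python
import pandas as pd

# Hypothetical fitted parameters, shaped like self.scalers[col]['scaler']
params = {'mean': 12.5, 'std': 3.2}

raw = pd.Series([10.0, 12.5, 18.9], name='temperature')

# Forward pass, as in transform(): z = (x - mean) / std
scaled = (raw - params['mean']) / params['std']

# Backward pass, as in inverse_transform(): x = z * std + mean
restored = scaled * params['std'] + params['mean']

assert (restored - raw).abs().max() < 1e-9
```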
splitting/__init__.py ADDED
File without changes
splitting/data_splitter.py ADDED
@@ -0,0 +1,403 @@
1
+ # ============================================
2
+ # CLASS 9: DATA SPLITTING
3
+ # ============================================
4
+ from datetime import datetime
5
+ from typing import Dict, Optional, Tuple
6
+ import logging
+
+ logger = logging.getLogger(__name__)
7
+ import pandas as pd
8
+ from config.config import Config
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+
12
+
13
+ class DataSplitter:
14
+ """Class for splitting data into train, validation and test sets"""
15
+
16
+ def __init__(self, config: Config):
17
+ """
18
+ Initialise data splitter
19
+
20
+ Parameters:
21
+ -----------
22
+ config : Config
23
+ Experiment configuration
24
+ """
25
+ self.config = config
26
+ self.split_info = {}
27
+ self.split_indices = {}
28
+ self.split_strategy = None
29
+
30
+ def split(
31
+ self,
32
+ data: pd.DataFrame,
33
+ test_size: Optional[float] = None,
34
+ validation_size: Optional[float] = None,
35
+ method: str = None,
36
+ random_state: int = 42,
37
+ **kwargs
38
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
39
+ """
40
+ Split data into train, validation and test sets
41
+
42
+ Parameters:
43
+ -----------
44
+ data : pd.DataFrame
45
+ Input data
46
+ test_size : float, optional
47
+ Test set size. If None, uses configuration value.
48
+ validation_size : float, optional
49
+ Validation set size. If None, uses configuration value.
50
+ method : str, optional
51
+ Splitting method: 'time', 'random', 'expanding_window', 'sliding_window'
52
+ random_state : int
53
+ Seed for reproducibility
54
+ **kwargs : dict
55
+ Additional parameters for method
56
+
57
+ Returns:
58
+ --------
59
+ Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
60
+ Train, validation and test data
61
+ """
62
+ logger.info("\n" + "="*80)
63
+ logger.info("DATA SPLITTING")
64
+ logger.info("="*80)
65
+
66
+ test_size = test_size or self.config.test_size
67
+ validation_size = validation_size or self.config.validation_size
68
+ method = method or self.config.split_method
69
+
70
+ n = len(data)
71
+
72
+ logger.info(f"Total data: {n} records")
73
+ logger.info(f"Splitting method: {method}")
74
+ logger.info(f"Sizes: train={1-test_size-validation_size:.1%}, val={validation_size:.1%}, test={test_size:.1%}")
75
+
76
+ if method == 'time':
77
+ train_data, val_data, test_data = self._time_based_split(
78
+ data, test_size, validation_size
79
+ )
80
+ elif method == 'random':
81
+ train_data, val_data, test_data = self._random_split(
82
+ data, test_size, validation_size, random_state
83
+ )
84
+ elif method == 'expanding_window':
85
+ train_data, val_data, test_data = self._expanding_window_split(
86
+ data, test_size, validation_size, **kwargs
87
+ )
88
+ elif method == 'sliding_window':
89
+ train_data, val_data, test_data = self._sliding_window_split(
90
+ data, **kwargs
91
+ )
92
+ else:
93
+ logger.warning(f"Method {method} not supported, using time-based split")
94
+ train_data, val_data, test_data = self._time_based_split(
95
+ data, test_size, validation_size
96
+ )
97
+
98
+ # Save splitting information
99
+ self._save_split_info(data, train_data, val_data, test_data, method)
100
+
101
+ # Output information
102
+ self._log_split_summary(train_data, val_data, test_data)
103
+
104
+ # Visualisation of split
105
+ if self.config.save_plots:
106
+ self._plot_data_split(data, train_data, val_data, test_data)
107
+
108
+ return train_data, val_data, test_data
109
+
110
+ def _time_based_split(
111
+ self,
112
+ data: pd.DataFrame,
113
+ test_size: float,
114
+ validation_size: float
115
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
116
+ """Time-based splitting preserving temporal order"""
117
+ n = len(data)
118
+
119
+ # Calculate set sizes
120
+ test_size_int = int(n * test_size)
121
+ val_size_int = int(n * validation_size)
122
+ train_size_int = n - test_size_int - val_size_int
123
+
124
+ # Split data
125
+ train_data = data.iloc[:train_size_int].copy()
126
+ val_data = data.iloc[train_size_int:train_size_int + val_size_int].copy()
127
+ test_data = data.iloc[train_size_int + val_size_int:].copy()
128
+
129
+ self.split_strategy = 'time_based'
130
+
131
+ return train_data, val_data, test_data
132
+
133
+ def _random_split(
134
+ self,
135
+ data: pd.DataFrame,
136
+ test_size: float,
137
+ validation_size: float,
138
+ random_state: int
139
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
140
+ """Random data splitting"""
141
+ from sklearn.model_selection import train_test_split
142
+
143
+ # First split into train+val and test
144
+ train_val_data, test_data = train_test_split(
145
+ data,
146
+ test_size=test_size,
147
+ random_state=random_state,
148
+ shuffle=True
149
+ )
150
+
151
+ # Then split train+val into train and val
152
+ val_relative_size = validation_size / (1 - test_size)
153
+ train_data, val_data = train_test_split(
154
+ train_val_data,
155
+ test_size=val_relative_size,
156
+ random_state=random_state,
157
+ shuffle=True
158
+ )
159
+
160
+ self.split_strategy = 'random'
161
+
162
+ return train_data, val_data, test_data
163
+
164
+ def _expanding_window_split(
165
+ self,
166
+ data: pd.DataFrame,
167
+ test_size: float,
168
+ validation_size: float,
169
+ **kwargs
170
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
171
+ """Expanding window split"""
172
+ n = len(data)
173
+
174
+ # Minimum initial window size
175
+ initial_window = kwargs.get('initial_window', max(100, int(n * 0.1)))
176
+
177
+ # Final set sizes
178
+ test_size_int = int(n * test_size)
179
+ val_size_int = int(n * validation_size)
180
+
181
+ # Determine boundaries
182
+ test_start = n - test_size_int
183
+ val_start = test_start - val_size_int
184
+
185
+ # For expanding window, use all data up to val_start for training
186
+ train_data = data.iloc[:val_start].copy()
187
+ val_data = data.iloc[val_start:test_start].copy()
188
+ test_data = data.iloc[test_start:].copy()
189
+
190
+ self.split_strategy = 'expanding_window'
191
+
192
+ return train_data, val_data, test_data
193
+
194
+ def _sliding_window_split(
195
+ self,
196
+ data: pd.DataFrame,
197
+ **kwargs
198
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
199
+ """Sliding window split (for multiple train-val-test pairs)"""
200
+ window_size = kwargs.get('window_size', len(data) // 3)
201
+ step = kwargs.get('step', window_size // 2)  # currently unused; a multi-split version would iterate with this stride
202
+
203
+ # For simplicity, return a single split
204
+ # A fuller implementation could return a list of (train, val, test) windows
205
+ n = len(data)
206
+
207
+ train_end = n - window_size
208
+ val_end = train_end + window_size // 3
209
+ test_end = n
210
+
211
+ train_data = data.iloc[:train_end].copy()
212
+ val_data = data.iloc[train_end:val_end].copy()
213
+ test_data = data.iloc[val_end:].copy()
214
+
215
+ self.split_strategy = 'sliding_window'
216
+
217
+ return train_data, val_data, test_data
218
+
219
+ def _save_split_info(
220
+ self,
221
+ full_data: pd.DataFrame,
222
+ train_data: pd.DataFrame,
223
+ val_data: pd.DataFrame,
224
+ test_data: pd.DataFrame,
225
+ method: str
226
+ ) -> None:
227
+ """Save splitting information"""
228
+ n = len(full_data)
229
+
230
+ self.split_info = {
231
+ 'method': method,
232
+ 'strategy': self.split_strategy,
233
+ 'train_size': len(train_data),
234
+ 'val_size': len(val_data),
235
+ 'test_size': len(test_data),
236
+ 'train_percent': len(train_data) / n * 100,
237
+ 'val_percent': len(val_data) / n * 100,
238
+ 'test_percent': len(test_data) / n * 100,
239
+ 'total_samples': n,
240
+ 'features_count': len(full_data.columns),
241
+ 'split_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
242
+ }
243
+
244
+ # Add temporal period information if available
245
+ if isinstance(full_data.index, pd.DatetimeIndex):
246
+ self.split_info.update({
247
+ 'train_period': {
248
+ 'start': train_data.index.min().strftime('%Y-%m-%d'),
249
+ 'end': train_data.index.max().strftime('%Y-%m-%d')
250
+ },
251
+ 'val_period': {
252
+ 'start': val_data.index.min().strftime('%Y-%m-%d'),
253
+ 'end': val_data.index.max().strftime('%Y-%m-%d')
254
+ },
255
+ 'test_period': {
256
+ 'start': test_data.index.min().strftime('%Y-%m-%d'),
257
+ 'end': test_data.index.max().strftime('%Y-%m-%d')
258
+ }
259
+ })
260
+
261
+ # Save split indices
262
+ self.split_indices = {
263
+ 'train': train_data.index.tolist(),
264
+ 'val': val_data.index.tolist(),
265
+ 'test': test_data.index.tolist()
266
+ }
267
+
268
+ def _log_split_summary(
269
+ self,
270
+ train_data: pd.DataFrame,
271
+ val_data: pd.DataFrame,
272
+ test_data: pd.DataFrame
273
+ ) -> None:
274
+ """Log splitting summary"""
275
+ logger.info("✓ Data split completed:")
276
+ logger.info(f" Train: {len(train_data)} records ({self.split_info['train_percent']:.1f}%)")
277
+ logger.info(f" Validation: {len(val_data)} records ({self.split_info['val_percent']:.1f}%)")
278
+ logger.info(f" Test: {len(test_data)} records ({self.split_info['test_percent']:.1f}%)")
279
+
280
+ if 'train_period' in self.split_info:
281
+ logger.info(f"\nPeriods:")
282
+ logger.info(f" Train: {self.split_info['train_period']['start']} - {self.split_info['train_period']['end']}")
283
+ logger.info(f" Validation: {self.split_info['val_period']['start']} - {self.split_info['val_period']['end']}")
284
+ logger.info(f" Test: {self.split_info['test_period']['start']} - {self.split_info['test_period']['end']}")
285
+
286
+ # Target variable statistics
287
+ target = self.config.target_column
288
+ if target in train_data.columns:
289
+ logger.info(f"\nTarget variable '{target}' statistics:")
290
+ logger.info(f" Train: mean={train_data[target].mean():.2f}, std={train_data[target].std():.2f}")
291
+ logger.info(f" Validation: mean={val_data[target].mean():.2f}, std={val_data[target].std():.2f}")
292
+ logger.info(f" Test: mean={test_data[target].mean():.2f}, std={test_data[target].std():.2f}")
293
+
294
+ def _plot_data_split(
295
+ self,
296
+ full_data: pd.DataFrame,
297
+ train_data: pd.DataFrame,
298
+ val_data: pd.DataFrame,
299
+ test_data: pd.DataFrame
300
+ ) -> None:
301
+ """Visualise data splitting"""
302
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
303
+
304
+ target = self.config.target_column
305
+
306
+ # 1. Time series with set highlighting
307
+ if target in full_data.columns and isinstance(full_data.index, pd.DatetimeIndex):
308
+ axes[0, 0].plot(train_data.index, train_data[target],
309
+ label='Train', color='blue', alpha=0.7, linewidth=1)
310
+ axes[0, 0].plot(val_data.index, val_data[target],
311
+ label='Validation', color='orange', alpha=0.7, linewidth=1)
312
+ axes[0, 0].plot(test_data.index, test_data[target],
313
+ label='Test', color='red', alpha=0.7, linewidth=1)
314
+
315
+ axes[0, 0].set_title(f'Data Split: {target}')
316
+ axes[0, 0].set_xlabel('Date')
317
+ axes[0, 0].set_ylabel(target)
318
+ axes[0, 0].legend()
319
+ axes[0, 0].grid(True, alpha=0.3)
320
+
321
+ # 2. Yearly distribution
322
+ if isinstance(full_data.index, pd.DatetimeIndex):
323
+ full_data['year'] = full_data.index.year
324
+ train_data['year'] = train_data.index.year
325
+ val_data['year'] = val_data.index.year
326
+ test_data['year'] = test_data.index.year
327
+
328
+ years = sorted(full_data['year'].unique())
329
+ train_counts = [len(train_data[train_data['year'] == year]) for year in years]
330
+ val_counts = [len(val_data[val_data['year'] == year]) for year in years]
331
+ test_counts = [len(test_data[test_data['year'] == year]) for year in years]
332
+
333
+ x = np.arange(len(years))
334
+ width = 0.25
335
+
336
+ axes[0, 1].bar(x - width, train_counts, width, label='Train', color='blue', alpha=0.7)
337
+ axes[0, 1].bar(x, val_counts, width, label='Validation', color='orange', alpha=0.7)
338
+ axes[0, 1].bar(x + width, test_counts, width, label='Test', color='red', alpha=0.7)
339
+
340
+ axes[0, 1].set_title('Yearly Data Distribution')
341
+ axes[0, 1].set_xlabel('Year')
342
+ axes[0, 1].set_ylabel('Number of Records')
343
+ axes[0, 1].set_xticks(x)
344
+ axes[0, 1].set_xticklabels(years, rotation=45)
345
+ axes[0, 1].legend()
346
+ axes[0, 1].grid(True, alpha=0.3)
347
+
348
+ # Remove added columns
349
+ for df in [full_data, train_data, val_data, test_data]:
350
+ if 'year' in df.columns:
351
+ df.drop('year', axis=1, inplace=True)
352
+
353
+ # 3. Target variable distribution
354
+ if target in full_data.columns:
355
+ axes[1, 0].hist(train_data[target].dropna(), bins=30, alpha=0.5, label='Train', density=True)
356
+ axes[1, 0].hist(val_data[target].dropna(), bins=30, alpha=0.5, label='Validation', density=True)
357
+ axes[1, 0].hist(test_data[target].dropna(), bins=30, alpha=0.5, label='Test', density=True)
358
+
359
+ axes[1, 0].set_title(f'{target} Distribution Across Sets')
360
+ axes[1, 0].set_xlabel(target)
361
+ axes[1, 0].set_ylabel('Density')
362
+ axes[1, 0].legend()
363
+ axes[1, 0].grid(True, alpha=0.3)
364
+
365
+ # 4. Set statistics
366
+ if target in full_data.columns:
367
+ stats_data = []
368
+ for name, df in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
369
+ if target in df.columns:
370
+ stats_data.append({
371
+ 'Dataset': name,
372
+ 'Mean': df[target].mean(),
373
+ 'Std': df[target].std(),
374
+ 'Min': df[target].min(),
375
+ 'Max': df[target].max()
376
+ })
377
+
378
+ if stats_data:
379
+ stats_df = pd.DataFrame(stats_data)
380
+ stats_table = axes[1, 1].table(
381
+ cellText=stats_df.round(2).values,
382
+ colLabels=stats_df.columns,
383
+ cellLoc='center',
384
+ loc='center'
385
+ )
386
+ stats_table.auto_set_font_size(False)
387
+ stats_table.set_fontsize(9)
388
+ stats_table.scale(1, 1.5)
389
+ axes[1, 1].axis('off')
390
+ axes[1, 1].set_title('Set Statistics')
391
+
392
+ plt.suptitle(f'Data Splitting: {self.split_info["method"]} method', fontsize=14)
393
+ plt.tight_layout()
394
+ plt.savefig(
395
+ f'{self.config.results_dir}/plots/data_split.png',
396
+ dpi=300,
397
+ bbox_inches='tight'
398
+ )
399
+ plt.show()
400
+
401
+ def get_report(self) -> Dict:
402
+ """Get data splitting report"""
403
+ return self.split_info
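
For reference, `_time_based_split` is plain positional arithmetic with no shuffling, so nothing from the future can leak into the training window. A self-contained sketch with hypothetical sizes (15% test, 15% validation; the real values come from `Config`):

```python
import numpy as np
import pandas as pd

# Hypothetical daily series and split fractions (not the Config defaults)
idx = pd.date_range('2020-01-01', periods=1000, freq='D')
data = pd.DataFrame({'y': np.random.default_rng(42).normal(size=1000)}, index=idx)
test_size, validation_size = 0.15, 0.15

n = len(data)
n_test = int(n * test_size)
n_val = int(n * validation_size)
n_train = n - n_test - n_val

# Chronological order is preserved: train -> validation -> test
train = data.iloc[:n_train]
val = data.iloc[n_train:n_train + n_val]
test = data.iloc[n_train + n_val:]

assert len(train) + len(val) + len(test) == n
assert train.index.max() < val.index.min()
assert val.index.max() < test.index.min()
```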
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
stationarity/__init__.py ADDED
File without changes
stationarity/stationarity_checker.py ADDED
@@ -0,0 +1,631 @@
1
+ # ============================================
2
+ # CLASS 6: STATIONARITY ANALYSIS
3
+ # ============================================
4
+ from typing import Dict, Optional
5
+ import logging
+
+ logger = logging.getLogger(__name__)
6
+ from config.config import Config
7
+ import pandas as pd
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from statsmodels.tsa.stattools import adfuller, kpss, acf, pacf
11
+ from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
12
+
13
+ class StationarityChecker:
14
+ """Class for checking time series stationarity"""
15
+
16
+ def __init__(self, config: Config):
17
+ """
18
+ Initialise stationarity checker
19
+
20
+ Parameters:
21
+ -----------
22
+ config : Config
23
+ Experiment configuration
24
+ """
25
+ self.config = config
26
+ self.test_results = {}
27
+ self.transformed_series = {}
28
+ self.best_transformation = {}
+ self.transformed_data = None
29
+
30
+ def check(
31
+ self,
32
+ data: pd.DataFrame,
33
+ target_col: Optional[str] = None,
34
+ make_stationary: bool = True,
35
+ try_transformations: bool = True
36
+ ) -> Dict:
37
+ """
38
+ Check time series stationarity
39
+
40
+ Parameters:
41
+ -----------
42
+ data : pd.DataFrame
43
+ Input data
44
+ target_col : str, optional
45
+ Target variable. If None, uses configuration default.
46
+ make_stationary : bool
47
+ Transform series to stationary form
48
+ try_transformations : bool
49
+ Try various transformations to achieve stationarity
50
+
51
+ Returns:
52
+ --------
53
+ Dict
54
+ Stationarity test results
55
+ """
56
+ logger.info("\n" + "="*80)
57
+ logger.info("STATIONARITY ANALYSIS")
58
+ logger.info("="*80)
59
+
60
+ target_col = target_col or self.config.target_column
61
+
62
+ if target_col not in data.columns:
63
+ logger.error(f"Target variable '{target_col}' not found")
64
+ return {}
65
+
66
+ series = data[target_col].dropna()
67
+
68
+ if len(series) < 10:
69
+ logger.warning("Insufficient data for stationarity analysis")
70
+ return {}
71
+
72
+ # Perform analysis
73
+ results = self._perform_stationarity_tests(series, target_col)
74
+
75
+ # Save results
76
+ self.test_results[target_col] = results
77
+
78
+ # Visualisation
79
+ if self.config.save_plots:
80
+ self._plot_stationarity_analysis(data, target_col, results)
81
+
82
+ # Log results
83
+ self._log_test_results(target_col, results)
84
+
85
+ # Transform to stationary form
86
+ if make_stationary and not results['overall']['is_stationary']:
87
+ if try_transformations:
88
+ transformed_data = self._make_stationary(data, target_col, results)
89
+ if transformed_data is not None:
90
+ # keep the augmented frame for callers; check() itself returns only the test dict
+ self.transformed_data = transformed_data
91
+
92
+ return results
93
+
94
+
95
+ def _perform_stationarity_tests(
96
+ self,
97
+ series: pd.Series,
98
+ target_col: str
99
+ ) -> Dict:
100
+ """Perform various stationarity tests"""
101
+ results = {
102
+ 'adf': self._adf_test(series),
103
+ 'kpss': self._kpss_test(series),
104
+ 'pp': self._pp_test(series),
105
+ 'hurst': self._hurst_exponent(series),
106
+ 'variance_ratio': self._variance_ratio_test(series),
107
+ 'overall': {}
108
+ }
109
+
110
+ # Determine overall stationarity
111
+ adf_stationary = results['adf'].get('is_stationary', False)
112
+ kpss_stationary = results['kpss'].get('is_stationary', False)
113
+ pp_stationary = results['pp'].get('is_stationary', False)
114
+
115
+ # Stationarity determination logic
116
+ if adf_stationary and kpss_stationary:
117
+ overall_stationary = True
118
+ confidence = 'high'
119
+ elif adf_stationary and not kpss_stationary:
120
+ overall_stationary = True  # tests disagree: ADF rejects a unit root, so lean towards stationary at reduced confidence
121
+ confidence = 'medium'
122
+ elif not adf_stationary and kpss_stationary:
123
+ overall_stationary = False # KPSS indicates non-stationarity
124
+ confidence = 'medium'
125
+ else:
126
+ overall_stationary = False
127
+ confidence = 'high'
128
+
129
+ results['overall'] = {
130
+ 'is_stationary': overall_stationary,
131
+ 'confidence': confidence,
132
+ 'recommendation': self._get_stationarity_recommendation(results)
133
+ }
134
+
135
+ return results
136
+
137
+ def _adf_test(self, series: pd.Series) -> Dict:
138
+ """Augmented Dickey-Fuller (ADF) test"""
139
+ try:
140
+ adf_result = adfuller(series, autolag='AIC')
141
+
142
+ return {
143
+ 'statistic': float(adf_result[0]),
144
+ 'pvalue': float(adf_result[1]),
145
+ 'critical_values': {k: float(v) for k, v in adf_result[4].items()},
146
+ 'is_stationary': adf_result[1] < 0.05,
147
+ 'used_lag': int(adf_result[2]),
148
+ 'nobs': int(adf_result[3])
149
+ }
150
+ except Exception as e:
151
+ logger.warning(f"ADF test failed: {e}")
152
+ return {
153
+ 'statistic': np.nan,
154
+ 'pvalue': np.nan,
155
+ 'critical_values': {},
156
+ 'is_stationary': False,
157
+ 'error': str(e)
158
+ }
159
+
160
+ def _kpss_test(self, series: pd.Series) -> Dict:
161
+ """KPSS test"""
162
+ try:
163
+ kpss_result = kpss(series, regression='c', nlags='auto')
164
+
165
+ return {
166
+ 'statistic': float(kpss_result[0]),
167
+ 'pvalue': float(kpss_result[1]),
168
+ 'critical_values': {k: float(v) for k, v in kpss_result[3].items()},
169
+ 'is_stationary': kpss_result[1] > 0.05, # KPSS: p > 0.05 indicates stationarity
170
+ 'used_lag': int(kpss_result[2])
171
+ }
172
+ except Exception as e:
173
+ logger.warning(f"KPSS test failed: {e}")
174
+ return {
175
+ 'statistic': np.nan,
176
+ 'pvalue': np.nan,
177
+ 'critical_values': {},
178
+ 'is_stationary': False,
179
+ 'error': str(e)
180
+ }
181
+
182
+ def _pp_test(self, series: pd.Series) -> Dict:
183
+ """Phillips-Perron test"""
184
+ try:
185
+ # Phillips-Perron test (provided by the arch package, not statsmodels)
186
+ from arch.unitroot import PhillipsPerron
187
+
188
+ pp_result = PhillipsPerron(series)
189
+
190
+ return {
191
+ 'statistic': float(pp_result.stat),
192
+ 'pvalue': float(pp_result.pvalue),
193
+ 'critical_values': pp_result.critical_values,
194
+ 'is_stationary': pp_result.pvalue < 0.05
195
+ }
196
+ except Exception:
197
+ # The arch package is not installed, so the PP test is unavailable
198
+ return {
199
+ 'statistic': np.nan,
200
+ 'pvalue': np.nan,
201
+ 'critical_values': {},
202
+ 'is_stationary': False,
203
+ 'note': 'Phillips-Perron test not available'
204
+ }
205
+
206
+ def _hurst_exponent(self, series: pd.Series) -> Dict:
207
+ """Calculate Hurst exponent"""
208
+ try:
209
+ # Simplified Hurst exponent calculation
210
+ lags = range(2, min(100, len(series)//4))
211
+ tau = []
212
+
213
+ for lag in lags:
214
+ # Split series into subsequences of length lag
215
+ n = len(series) // lag
216
+ if n < 2:
217
+ continue
218
+
219
+ subseries = [series[i*lag:(i+1)*lag] for i in range(n)]
220
+ # Calculate R/S for each subsequence
221
+ rs_values = []
222
+ for sub in subseries:
223
+ if len(sub) > 1:
224
+ mean = np.mean(sub)
225
+ deviations = sub - mean
226
+ z = np.cumsum(deviations)
227
+ r = np.max(z) - np.min(z)
228
+ s = np.std(sub)
229
+ if s > 0:
230
+ rs_values.append(r / s)
231
+
232
+ if rs_values:
233
+ tau.append(np.mean(rs_values))
234
+
235
+ if len(tau) > 2:
236
+ # Linear regression in log coordinates
237
+ x = np.log(lags[:len(tau)])
238
+ y = np.log(tau)
239
+
240
+ if len(x) > 1 and len(y) > 1:
241
+ slope = np.polyfit(x, y, 1)[0]
242
+
243
+ # Hurst exponent interpretation
244
+ if slope > 0.5:
245
+ trend_type = 'persistent'
246
+ elif slope < 0.5:
247
+ trend_type = 'anti-persistent'
248
+ else:
249
+ trend_type = 'random'
250
+
251
+ return {
252
+ 'exponent': float(slope),
253
+ 'trend_type': trend_type,
254
+ 'interpretation': self._interpret_hurst(slope)
255
+ }
256
+
257
+ return {
258
+ 'exponent': np.nan,
259
+ 'trend_type': 'unknown',
260
+ 'interpretation': 'Insufficient data'
261
+ }
262
+
263
+ except Exception as e:
264
+ logger.debug(f"Hurst exponent not calculated: {e}")
265
+ return {
266
+ 'exponent': np.nan,
267
+ 'trend_type': 'unknown',
268
+ 'error': str(e)
269
+ }
270
+
271
+ def _interpret_hurst(self, hurst_exponent: float) -> str:
272
+ """Interpret Hurst exponent"""
273
+ if hurst_exponent > 0.75:
274
+ return "Strong persistence (long-term memory)"
275
+ elif hurst_exponent > 0.6:
276
+ return "Moderate persistence"
277
+ elif hurst_exponent > 0.4:
278
+ return "Weak persistence / random walk"
279
+ elif hurst_exponent > 0.25:
280
+ return "Weak anti-persistence"
281
+ else:
282
+ return "Strong anti-persistence (frequent trend reversal)"
283
+
284
+ def _variance_ratio_test(self, series: pd.Series) -> Dict:
285
+ """Variance Ratio test for random walk"""
286
+ try:
287
+ # Simplified variance ratio test
288
+ if len(series) < 20:
289
+ return {'ratio': np.nan, 'is_random_walk': False}
290
+
291
+ # Calculate differences
292
+ diff1 = series.diff(1).dropna()
293
+ diff2 = series.diff(2).dropna()[1:] # Shift to align indices
294
+
295
+ if len(diff1) < 5 or len(diff2) < 5:
296
+ return {'ratio': np.nan, 'is_random_walk': False}
297
+
298
+ var1 = np.var(diff1)
299
+ var2 = np.var(diff2)
300
+
301
+ if var1 > 0:
302
+ ratio = var2 / (2 * var1)
303
+
304
+ # For random walk ratio ≈ 1
305
+ is_random_walk = 0.8 < ratio < 1.2
306
+
307
+ return {
308
+ 'ratio': float(ratio),
309
+ 'is_random_walk': bool(is_random_walk),
310
+ 'var_diff1': float(var1),
311
+ 'var_diff2': float(var2)
312
+ }
313
+ else:
314
+ return {'ratio': np.nan, 'is_random_walk': False}
315
+
316
+ except Exception as e:
317
+ logger.debug(f"Variance ratio test failed: {e}")
318
+ return {'ratio': np.nan, 'is_random_walk': False, 'error': str(e)}
319
+
320
+ def _get_stationarity_recommendation(self, results: Dict) -> str:
321
+ """Get stationarity recommendations"""
322
+ # Check for keys before access
323
+ if 'overall' not in results or 'is_stationary' not in results['overall']:
324
+ return "Could not determine stationarity. Check data and test settings."
325
+
326
+ if results['overall']['is_stationary']:
327
+ return "Series is stationary, suitable for modelling"
328
+ else:
329
+ recommendations = []
330
+
331
+ # Check Hurst test results
332
+ if 'hurst' in results and 'exponent' in results['hurst']:
333
+ hurst_exponent = results['hurst']['exponent']
334
+ if not np.isnan(hurst_exponent) and hurst_exponent > 0.6:
335
+ recommendations.append("Apply differencing to remove trend")
336
+
337
+ # Check ADF test
338
+ if 'adf' in results and 'pvalue' in results['adf']:
339
+ adf_pvalue = results['adf']['pvalue']
340
+ if not np.isnan(adf_pvalue) and adf_pvalue > 0.1:
341
+ recommendations.append("Consider seasonal differencing due to non-stationarity")
342
+
343
+ if len(recommendations) == 0:
344
+ recommendations.append("Try logarithmic transformation and differencing")
345
+
346
+ return "; ".join(recommendations)
347
+
348
+ def _plot_stationarity_analysis(
349
+ self,
350
+ data: pd.DataFrame,
351
+ target_col: str,
352
+ results: Dict
353
+ ) -> None:
354
+ """Visualise stationarity analysis"""
355
+ series = data[target_col]
356
+
357
+ fig, axes = plt.subplots(2, 3, figsize=(16, 10))
358
+
359
+ # 1. Original series
360
+ axes[0, 0].plot(series.index, series, linewidth=1)
361
+ axes[0, 0].set_title(f'Original Time Series: {target_col}')
362
+ axes[0, 0].set_xlabel('Date')
363
+ axes[0, 0].set_ylabel(target_col)
364
+ axes[0, 0].grid(True, alpha=0.3)
365
+
366
+ # 2. Rolling statistics
367
+ rolling_mean = series.rolling(window=365, center=True, min_periods=1).mean()  # 365-step window assumes daily data
368
+ rolling_std = series.rolling(window=365, center=True, min_periods=1).std()
369
+
370
+ axes[0, 1].plot(series.index, series, label='Original series', alpha=0.7, linewidth=0.5)
371
+ axes[0, 1].plot(rolling_mean.index, rolling_mean, label='Rolling mean (365)', color='red', linewidth=2)
372
+ axes[0, 1].plot(rolling_std.index, rolling_std, label='Rolling STD (365)', color='green', linewidth=2)
373
+ axes[0, 1].set_title(f'Rolling Statistics: {target_col}')
374
+ axes[0, 1].set_xlabel('Date')
375
+ axes[0, 1].set_ylabel(target_col)
376
+ axes[0, 1].legend(fontsize=8)
377
+ axes[0, 1].grid(True, alpha=0.3)
378
+
379
+ # 3. ACF
380
+ plot_acf(series.dropna(), lags=50, ax=axes[0, 2], alpha=0.05)
381
+ axes[0, 2].set_title(f'Autocorrelation Function (ACF): {target_col}')
382
+ axes[0, 2].set_xlabel('Lag')
383
+ axes[0, 2].set_ylabel('Autocorrelation')
384
+ axes[0, 2].grid(True, alpha=0.3)
385
+
386
+ # 4. PACF
387
+ plot_pacf(series.dropna(), lags=50, ax=axes[1, 0], alpha=0.05)
388
+ axes[1, 0].set_title(f'Partial Autocorrelation Function (PACF): {target_col}')
389
+ axes[1, 0].set_xlabel('Lag')
390
+ axes[1, 0].set_ylabel('Partial Autocorrelation')
391
+ axes[1, 0].grid(True, alpha=0.3)
392
+
393
+ # 5. Histogram and Q-Q plot
394
+ axes[1, 1].hist(series.dropna(), bins=30, edgecolor='black', alpha=0.7, density=True)
395
+ axes[1, 1].set_title(f'Distribution: {target_col}')
396
+ axes[1, 1].set_xlabel('Value')
397
+ axes[1, 1].set_ylabel('Density')
398
+ axes[1, 1].grid(True, alpha=0.3)
399
+
400
+ # 6. Series differences
401
+ diff1 = series.diff(1).dropna()
402
+ axes[1, 2].plot(diff1.index, diff1, linewidth=0.5)
403
+ axes[1, 2].set_title(f'First Difference: {target_col}')
404
+ axes[1, 2].set_xlabel('Date')
405
+ axes[1, 2].set_ylabel(f'Δ{target_col}')
406
+ axes[1, 2].grid(True, alpha=0.3)
407
+
408
+ plt.suptitle(
409
+ f'Stationarity Analysis: {target_col}\n'
410
+ f'Stationary: {"✓ Yes" if results["overall"]["is_stationary"] else "✗ No"} '
411
+ f'(confidence: {results["overall"]["confidence"]})',
412
+ fontsize=14
413
+ )
414
+
415
+ plt.tight_layout()
416
+ plt.savefig(
417
+ f'{self.config.results_dir}/plots/stationarity_{target_col}.png',
418
+ dpi=300,
419
+ bbox_inches='tight'
420
+ )
421
+ plt.show()
422
+
423
+ def _log_test_results(self, target_col: str, results: Dict) -> None:
424
+ """Log test results"""
425
+ logger.info("\nSTATIONARITY TEST RESULTS:")
426
+ logger.info("-" * 50)
427
+
428
+ # ADF test
429
+ adf = results['adf']
430
+ logger.info(f"Augmented Dickey-Fuller (ADF) test:")
431
+ logger.info(f" Statistic: {adf['statistic']:.4f}")
432
+ logger.info(f" p-value: {adf['pvalue']:.4f}")
433
+ logger.info(f" Stationary: {'✓ Yes' if adf['is_stationary'] else '✗ No'}")
434
+
435
+ # KPSS test
436
+ kpss_test = results['kpss']
437
+ if 'statistic' in kpss_test and not np.isnan(kpss_test['statistic']):
438
+ logger.info(f"\nKPSS test:")
439
+ logger.info(f" Statistic: {kpss_test['statistic']:.4f}")
440
+ logger.info(f" p-value: {kpss_test['pvalue']:.4f}")
441
+ logger.info(f" Stationary: {'✓ Yes' if kpss_test['is_stationary'] else '✗ No'}")
442
+
443
+ # Hurst exponent
444
+ hurst = results['hurst']
445
+ if 'exponent' in hurst and not np.isnan(hurst['exponent']):
446
+ logger.info(f"\nHurst exponent:")
447
+ logger.info(f" Value: {hurst['exponent']:.3f}")
448
+ logger.info(f" Trend type: {hurst['trend_type']}")
449
+ logger.info(f" Interpretation: {hurst.get('interpretation', '')}")
450
+
451
+ # Overall interpretation
452
+ logger.info(f"\nOVERALL CONCLUSION:")
453
+ logger.info("-" * 30)
454
+ logger.info(f"Stationary: {'✓ Yes' if results['overall']['is_stationary'] else '✗ No'}")
455
+ logger.info(f"Confidence: {results['overall']['confidence']}")
456
+ logger.info(f"Recommendation: {results['overall']['recommendation']}")
457
+
458
+ def _make_stationary(
459
+ self,
460
+ data: pd.DataFrame,
461
+ target_col: str,
462
+ results: Dict
463
+ ) -> Optional[pd.DataFrame]:
464
+ """
465
+ Transform series to stationary form
466
+
467
+ Parameters:
468
+ -----------
469
+ data : pd.DataFrame
470
+ Input data
471
+ target_col : str
472
+ Target variable
473
+ results : Dict
474
+ Stationarity test results
475
+
476
+ Returns:
477
+ --------
478
+ Optional[pd.DataFrame]
479
+ Data with stationary series or None if transformation failed
480
+ """
481
+ logger.info("\nTRANSFORMING TO STATIONARY FORM:")
482
+ logger.info("-" * 40)
483
+
484
+ data_processed = data.copy()
485
+ series = data_processed[target_col]
486
+
487
+ # Stationarisation methods in order of preference
488
+ methods = [
489
+ ('diff', 'first-order differencing'),
490
+ ('seasonal_diff', f'seasonal differencing (period={self.config.seasonal_period})'),
491
+ ('log_diff', 'logarithmic differencing'),
492
+ ('boxcox_diff', 'Box-Cox + differencing'),
493
+ ('detrend', 'detrending'),
494
+ ('combination', 'combined method')
495
+ ]
496
+
497
+ best_method = None
498
+ best_series = None
499
+ best_pvalue = 1.0
500
+ best_stationary = False
501
+
502
+ for method, method_name in methods:
503
+ try:
504
+ if method == 'diff':
505
+ # Simple differencing
506
+ transformed = series.diff(1).dropna()
507
+ test_series = transformed
508
+
509
+ elif method == 'seasonal_diff':
510
+ # Seasonal differencing
511
+ transformed = series.diff(self.config.seasonal_period).dropna()
512
+ test_series = transformed
513
+
514
+ elif method == 'log_diff':
515
+ # Logarithmic differencing
516
+ if (series > 0).all():
517
+ log_series = np.log(series)
518
+ transformed = log_series.diff(1).dropna()
519
+ test_series = transformed
520
+ else:
521
+ # Shift for negative values
522
+ shift = abs(series.min()) + 1 if series.min() <= 0 else 0
523
+ log_series = np.log(series + shift)
524
+ transformed = log_series.diff(1).dropna()
525
+ test_series = transformed
526
+
527
+ elif method == 'boxcox_diff':
528
+ # Box-Cox transformation + differencing
529
+ try:
530
+ from scipy.stats import boxcox
531
+ # Add constant for positive values
532
+ shift = abs(series.min()) + 1 if series.min() <= 0 else 0
533
+ boxcox_series, _ = boxcox(series + shift)
534
+ transformed = pd.Series(boxcox_series, index=series.index).diff(1).dropna()
535
+ test_series = transformed
536
+ except Exception:
537
+ continue
538
+
539
+ elif method == 'detrend':
540
+ # Linear detrending
541
+ x = np.arange(len(series))
542
+ y = series.values
543
+ coeffs = np.polyfit(x, y, 1)
544
+ trend = np.polyval(coeffs, x)
545
+ transformed = pd.Series(y - trend, index=series.index)
546
+ test_series = transformed
547
+
548
+ elif method == 'combination':
549
+ # Combined method: log + differencing + detrending
550
+ if (series > 0).all():
551
+ log_series = np.log(series)
552
+ diff_series = log_series.diff(1)
553
+
554
+ # Detrending residuals
555
+ x = np.arange(len(diff_series))
556
+ y = diff_series.values
557
+ valid_mask = ~np.isnan(y)
558
+
559
+ if valid_mask.sum() > 2:
560
+ coeffs = np.polyfit(x[valid_mask], y[valid_mask], 1)
561
+ trend = np.polyval(coeffs, x)
562
+ transformed = pd.Series(y - trend, index=series.index)
563
+ test_series = transformed.dropna()
564
+ else:
565
+ test_series = diff_series.dropna()
566
+ else:
567
+ continue
568
+
569
+ # Check stationarity after transformation
570
+ if len(test_series) > 10:
571
+ adf_result = adfuller(test_series.dropna())
572
+ is_stationary = adf_result[1] < 0.05
573
+ pvalue = adf_result[1]
574
+
575
+ logger.info(f" Method: {method_name}")
576
+ logger.info(f" ADF p-value: {pvalue:.4f}")
577
+ logger.info(f" Stationary: {'✓ Yes' if is_stationary else '✗ No'}")
578
+
579
+ # Save best method
580
+ if is_stationary and pvalue < best_pvalue:
581
+ best_pvalue = pvalue
582
+ best_method = method
583
+ best_series = transformed
584
+ best_stationary = True
585
+
586
+ if pvalue < 0.01: # Very good result
587
+ break
588
+
589
+ except Exception as e:
590
+ logger.debug(f" Method {method} failed: {e}")
591
+ continue
592
+
593
+ # Save results
594
+ if best_series is not None:
595
+ new_col_name = f'{target_col}_stationary_{best_method}'
596
+
597
+ # Align indices
598
+ aligned_series = pd.Series(
599
+ best_series.values,
600
+ index=data_processed.index[-len(best_series):]
601
+ )
602
+
603
+ data_processed[new_col_name] = aligned_series
604
+
605
+ self.transformed_series[target_col] = {
606
+ 'method': best_method,
607
+ 'new_column': new_col_name,
608
+ 'pvalue': float(best_pvalue),
609
+ 'is_stationary': best_stationary,
610
+ 'original_shape': len(series),
611
+ 'transformed_shape': len(best_series)
612
+ }
613
+
614
+ self.best_transformation[target_col] = best_method
615
+
616
+ logger.info(f"\n✓ Selected method: {best_method}")
617
+ logger.info(f" Saved as '{new_col_name}'")
618
+ logger.info(f" p-value: {best_pvalue:.4f}")
619
+
620
+ return data_processed
621
+ else:
622
+ logger.warning("✗ Could not find suitable transformation for stationarisation")
623
+ return None
624
+
625
+ def get_report(self) -> Dict:
626
+ """Get stationarity report"""
627
+ return {
628
+ 'test_results': self.test_results,
629
+ 'transformed_series': self.transformed_series,
630
+ 'best_transformations': self.best_transformation
631
+ }
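
The decision rule in `_perform_stationarity_tests` combines two tests with opposite null hypotheses: ADF assumes a unit root, KPSS assumes stationarity, so "ADF p < 0.05 and KPSS p > 0.05" is the agreeing case. A minimal sketch of that combination on a synthetic random walk, which first differencing stationarises:

```python
import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import adfuller, kpss

rng = np.random.default_rng(0)
walk = pd.Series(np.cumsum(rng.normal(size=500)))  # synthetic random walk (non-stationary)

adf_p = adfuller(walk, autolag='AIC')[1]
kpss_p = kpss(walk, regression='c', nlags='auto')[1]
print(f"level:      ADF p={adf_p:.3f} (stationary if < 0.05), KPSS p={kpss_p:.3f} (stationary if > 0.05)")

diff_p = adfuller(walk.diff().dropna(), autolag='AIC')[1]
print(f"difference: ADF p={diff_p:.3f}")
```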
streamlit/streamlit_app.py ADDED
The diff for this file is too large to render. See raw diff
 
temp_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
validation/__init__.py ADDED
File without changes
validation/data_validator.py ADDED
@@ -0,0 +1,655 @@
1
+ # ============================================
2
+ # CLASS 12: DATA VALIDATION
3
+ # ============================================
4
+ from datetime import datetime
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Dict, List
8
+ import logging
+
+ logger = logging.getLogger(__name__)
9
+
10
+ from config.config import Config
11
+ import pandas as pd
12
+ import numpy as np
13
+
14
+ class DataValidator:
15
+ """Class for data quality validation"""
16
+
17
+ def __init__(self, config: Config):
18
+ """
19
+ Initialise data validator
20
+
21
+ Parameters:
22
+ -----------
23
+ config : Config
24
+ Experiment configuration
25
+ """
26
+ self.config = config
27
+ self.validation_results = {}
28
+ self.quality_metrics = {}
29
+ self.issues_found = {}
30
+
31
+ def validate(
32
+ self,
33
+ data: pd.DataFrame,
34
+ stage: str = 'final',
35
+ rules: Dict = None,
36
+ detailed: bool = True
37
+ ) -> Dict:
38
+ """
39
+ Validate data quality
40
+
41
+ Parameters:
42
+ -----------
43
+ data : pd.DataFrame
44
+ Input data
45
+ stage : str
46
+ Validation stage: 'raw', 'processed', 'final'
47
+ rules : Dict, optional
48
+ Validation rules. If None, uses configuration defaults.
49
+ detailed : bool
50
+ Whether to perform detailed validation
51
+
52
+ Returns:
53
+ --------
54
+ Dict
55
+ Validation results
56
+ """
57
+ logger.info("\n" + "="*80)
58
+ logger.info(f"DATA VALIDATION ({stage.upper()})")
59
+ logger.info("="*80)
60
+
61
+ rules = rules or self.config.validation_rules
62
+
63
+ validation_results = {
64
+ 'stage': stage,
65
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
66
+ 'data_shape': list(data.shape),
67
+ 'basic_checks': {},
68
+ 'quality_metrics': {},
69
+ 'issues': {},
70
+ 'recommendations': [],
71
+ 'overall_score': 0,
72
+ 'status': 'PASS'
73
+ }
74
+
75
+ # Basic checks
76
+ validation_results['basic_checks'] = self._basic_checks(data, rules)
77
+
78
+ # Quality checks
79
+ validation_results['quality_metrics'] = self._quality_metrics(data, rules)
80
+
81
+ # Problem detection
82
+ if detailed:
83
+ validation_results['issues'] = self._find_issues(data, rules)
84
+
85
+ # Recommendation generation
86
+ validation_results['recommendations'] = self._generate_recommendations(
87
+ validation_results['basic_checks'],
88
+ validation_results['quality_metrics'],
89
+ validation_results['issues']
90
+ )
91
+
92
+ # Overall score calculation
93
+ validation_results['overall_score'] = self._calculate_overall_score(validation_results)
94
+
95
+ # Status determination
96
+ if validation_results['overall_score'] >= 80:
97
+ validation_results['status'] = 'PASS'
98
+ elif validation_results['overall_score'] >= 60:
99
+ validation_results['status'] = 'WARNING'
100
+ else:
101
+ validation_results['status'] = 'FAIL'
102
+
103
+ # Save results
104
+ self.validation_results[stage] = validation_results
105
+ self.quality_metrics[stage] = validation_results['quality_metrics']
106
+
107
+ # Log results
108
+ self._log_validation_results(validation_results)
109
+
110
+ return validation_results
111
+
112
+ def _basic_checks(self, data: pd.DataFrame, rules: Dict) -> Dict:
113
+ """Basic data checks"""
114
+ checks = {}
115
+
116
+ # 1. Data size check
117
+ checks['min_rows'] = {
118
+ 'value': len(data),
119
+ 'threshold': rules.get('min_rows', 100),
120
+ 'passed': len(data) >= rules.get('min_rows', 100)
121
+ }
122
+
123
+ # 2. Target variable presence check
124
+ target = self.config.target_column
125
+ checks['has_target'] = {
126
+ 'value': target in data.columns,
127
+ 'passed': target in data.columns
128
+ }
129
+
130
+ # 3. Missing values check
131
+ missing_percentage = (data.isnull().sum().sum() / data.size) * 100
132
+ checks['missing_percentage'] = {
133
+ 'value': missing_percentage,
134
+ 'threshold': rules.get('max_missing_percentage', 30),
135
+ 'passed': missing_percentage <= rules.get('max_missing_percentage', 30)
136
+ }
137
+
138
+ # 4. Duplicates check
139
+ duplicate_count = data.duplicated().sum()
140
+ duplicate_percentage = (duplicate_count / len(data)) * 100
141
+ checks['duplicates'] = {
142
+ 'value': duplicate_percentage,
143
+ 'threshold': 5, # Maximum 5% duplicates
144
+ 'passed': duplicate_percentage <= 5
145
+ }
146
+
147
+ # 5. Data types check
148
+ numeric_count = len(data.select_dtypes(include=[np.number]).columns)
149
+ checks['numeric_features'] = {
150
+ 'value': numeric_count,
151
+ 'passed': numeric_count >= 1 # At least one numeric feature required
152
+ }
153
+
154
+ return checks
155
+
156
+ def _quality_metrics(self, data: pd.DataFrame, rules: Dict) -> Dict:
157
+ """Data quality metrics"""
158
+ metrics = {}
159
+
160
+ # 1. Numeric features statistics
161
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
162
+
163
+ if len(numeric_cols) > 0:
164
+ numeric_stats = {}
165
+ for col in numeric_cols:
166
+ col_data = data[col].dropna()
167
+ if len(col_data) > 0:
168
+ numeric_stats[col] = {
169
+ 'mean': float(col_data.mean()),
170
+ 'std': float(col_data.std()),
171
+ 'skewness': float(col_data.skew()),
172
+ 'kurtosis': float(col_data.kurtosis()),
173
+ 'zeros_percentage': float((col_data == 0).sum() / len(col_data) * 100),
174
+ 'unique_percentage': float(col_data.nunique() / len(col_data) * 100)
175
+ }
176
+
177
+ metrics['numeric_statistics'] = numeric_stats
178
+
179
+ # 2. Data stability (for time series)
180
+ if isinstance(data.index, pd.DatetimeIndex):
181
+ stability_metrics = self._calculate_temporal_stability(data)
182
+ metrics['temporal_stability'] = stability_metrics
183
+
184
+ # 3. Feature informativeness
185
+ if self.config.target_column in data.columns:
186
+ informativeness = self._calculate_feature_informativeness(data)
187
+ metrics['feature_informativeness'] = informativeness
188
+
189
+ # 4. Target variable quality
190
+ target = self.config.target_column
191
+ if target in data.columns:
192
+ target_data = data[target].dropna()
193
+ if len(target_data) > 0:
194
+ target_metrics = {
195
+ # measured on the raw column; target_data already has its NaNs dropped
+ 'missing_percentage': float(data[target].isnull().sum() / len(data) * 100),
196
+ 'unique_values': int(target_data.nunique()),
197
+ 'is_constant': bool(target_data.nunique() <= 1),
198
+ 'has_outliers': self._check_target_outliers(target_data),
199
+ 'distribution_type': self._identify_distribution(target_data)
200
+ }
201
+ metrics['target_quality'] = target_metrics
202
+
203
+ # 5. Class balance (for classification) - not applicable here, but kept as placeholder
204
+ metrics['class_balance'] = {'note': 'Not applicable for regression'}
205
+
206
+ return metrics
207
+
208
+ def _calculate_temporal_stability(self, data: pd.DataFrame) -> Dict:
209
+ """Calculate time series stability metrics"""
210
+ stability = {}
211
+
212
+ if not isinstance(data.index, pd.DatetimeIndex):
213
+ return stability
214
+
215
+ # Split into periods (e.g., by years)
216
+ if 'year' not in data.columns:
217
+ data_copy = data.copy()
218
+ data_copy['year'] = data_copy.index.year
219
+ else:
220
+ data_copy = data
221
+
222
+ years = sorted(data_copy['year'].unique())
223
+
224
+ if len(years) > 1:
225
+ # Statistics by years for numeric columns
226
+ year_stats = {}
227
+ for col in data.select_dtypes(include=[np.number]).columns[:5]: # First 5 columns
228
+ yearly_means = data_copy.groupby('year')[col].mean()
229
+ yearly_stds = data_copy.groupby('year')[col].std()
230
+
231
+ # Coefficient of variation between years
232
+ if yearly_means.std() > 0:
233
+ cv_between_years = yearly_means.std() / yearly_means.mean()
234
+ else:
235
+ cv_between_years = 0
236
+
237
+ year_stats[col] = {
238
+ 'yearly_means': yearly_means.to_dict(),
239
+ 'yearly_stds': yearly_stds.to_dict(),
240
+ 'cv_between_years': float(cv_between_years),
241
+ 'mean_stability': float(1 - cv_between_years) # 1 - CV, closer to 1 means more stable
242
+ }
243
+
244
+ stability['yearly_statistics'] = year_stats
245
+
246
+ # Check for time gaps
247
+ time_diff = pd.Series(data.index).diff().dropna()
248
+ if len(time_diff) > 0:
249
+ max_gap = time_diff.max()
250
+ avg_gap = time_diff.mean()
251
+ gap_std = time_diff.std()
252
+
253
+ stability['time_gaps'] = {
254
+ 'max_gap_days': float(max_gap.days if hasattr(max_gap, 'days') else max_gap),
255
+ 'avg_gap_days': float(avg_gap.days if hasattr(avg_gap, 'days') else avg_gap),
256
+ 'gap_std': float(gap_std.days if hasattr(gap_std, 'days') else gap_std),
257
+ 'has_irregular_gaps': gap_std > avg_gap * 0.5 # If standard deviation > 50% of mean
258
+ }
259
+
260
+ # Seasonal stability
261
+ if len(data) > 365:
262
+ try:
263
+ # Analyse seasonal patterns
264
+ seasonal_stability = self._analyse_seasonal_stability(data)
265
+ stability['seasonal_stability'] = seasonal_stability
266
+ except Exception:
267
+ pass
268
+
269
+ return stability
270
+
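# Illustrative sketch (not part of this diff): the between-year coefficient of
# variation computed above, reproduced standalone on synthetic daily data.
import numpy as np
import pandas as pd

idx = pd.date_range('2020-01-01', '2022-12-31', freq='D')
df = pd.DataFrame({'y': np.random.default_rng(0).normal(100, 10, len(idx))}, index=idx)
yearly_means = df.groupby(df.index.year)['y'].mean()
cv_between_years = yearly_means.std() / yearly_means.mean()  # near 0 => stable series
print(f"mean_stability = {1 - cv_between_years:.3f}")  # the 1 - CV metric stored above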
271
+ def _analyse_seasonal_stability(self, data: pd.DataFrame) -> Dict:
272
+ """Analyse seasonal patterns stability"""
273
+ if not isinstance(data.index, pd.DatetimeIndex):
274
+ return {}
275
+
276
+ # For simplicity, analyse only target variable
277
+ target = self.config.target_column
278
+ if target not in data.columns:
279
+ return {}
280
+
283
+ # Split by years and compare seasonal patterns
284
+ data_copy = data.copy()
285
+ data_copy['year'] = data_copy.index.year
286
+ data_copy['month'] = data_copy.index.month
287
+
288
+ if 'year' in data_copy.columns and 'month' in data_copy.columns:
289
+ monthly_means = data_copy.groupby(['year', 'month'])[target].mean().unstack()
290
+
291
+ if not monthly_means.empty:
292
+ # Average correlation between years' monthly profiles (transpose so columns are years)
293
+ yearly_corr = monthly_means.T.corr().mean().mean()
294
+
295
+ # Variation between years
296
+ monthly_cv = monthly_means.std() / monthly_means.mean()
297
+ avg_monthly_cv = monthly_cv.mean()
298
+
299
+ return {
300
+ 'yearly_correlation': float(yearly_corr),
301
+ 'average_monthly_cv': float(avg_monthly_cv),
302
+ 'seasonal_consistency': 'high' if yearly_corr > 0.8 and avg_monthly_cv < 0.3 else
303
+ 'medium' if yearly_corr > 0.6 else 'low'
304
+ }
305
+
306
+ return {}
307
+
308
+ def _calculate_feature_informativeness(self, data: pd.DataFrame) -> Dict:
309
+ """Calculate feature informativeness"""
310
+ informativeness = {}
311
+
312
+ target = self.config.target_column
313
+ if target not in data.columns:
314
+ return informativeness
315
+
316
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
317
+ numeric_cols = [col for col in numeric_cols if col != target]
318
+
319
+ for col in numeric_cols[:20]: # Limit number of features for analysis
320
+ try:
321
+ # Correlation with target variable
322
+ correlation = data[col].corr(data[target])
+ if pd.isna(correlation):
+ continue  # constant or all-NaN feature: correlation is undefined
323
+
324
+ # Mutual information (approximated)
325
+ # For simplicity, use absolute correlation as informativeness measure
326
+ informativeness[col] = {
327
+ 'correlation_with_target': float(correlation),
328
+ 'abs_correlation': float(abs(correlation)),
329
+ 'informativeness': 'high' if abs(correlation) > 0.5 else
330
+ 'medium' if abs(correlation) > 0.3 else 'low'
331
+ }
332
+ except Exception:
333
+ continue
334
+
335
+ return informativeness
336
+
337
+ def _check_target_outliers(self, target_series: pd.Series) -> Dict:
338
+ """Check target variable for outliers"""
339
+ if len(target_series) < 10:
340
+ return {'has_outliers': False, 'outlier_percentage': 0}
341
+
342
+ q1 = target_series.quantile(0.25)
343
+ q3 = target_series.quantile(0.75)
344
+ iqr = q3 - q1
345
+
346
+ if iqr > 0:
347
+ lower_bound = q1 - 1.5 * iqr
348
+ upper_bound = q3 + 1.5 * iqr
349
+
350
+ outliers = target_series[(target_series < lower_bound) | (target_series > upper_bound)]
351
+ outlier_percentage = len(outliers) / len(target_series) * 100
352
+
353
+ return {
354
+ 'has_outliers': len(outliers) > 0,
355
+ 'outlier_count': int(len(outliers)),
356
+ 'outlier_percentage': float(outlier_percentage),
357
+ 'outlier_bounds': {'lower': float(lower_bound), 'upper': float(upper_bound)}
358
+ }
359
+
360
+ return {'has_outliers': False, 'outlier_percentage': 0}
361
+
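# Illustrative sketch (not part of this diff): the 1.5*IQR rule applied above,
# on a toy series where 40 is an obvious outlier.
import pandas as pd

s = pd.Series([1, 2, 2, 3, 3, 3, 4, 4, 5, 40])
q1, q3 = s.quantile(0.25), s.quantile(0.75)
iqr = q3 - q1
mask = (s < q1 - 1.5 * iqr) | (s > q3 + 1.5 * iqr)
print(s[mask].tolist(), f"-> {mask.mean() * 100:.1f}% outliers")  # [40] -> 10.0%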
362
+ def _identify_distribution(self, series: pd.Series) -> str:
363
+ """Identify distribution type"""
364
+ if len(series) < 30:
365
+ return 'insufficient_data'
366
+
367
+ skewness = series.skew()
368
+ kurtosis = series.kurtosis()  # pandas reports excess (Fisher) kurtosis: ~0 for a normal distribution
369
+
370
+ if abs(skewness) < 0.5 and abs(kurtosis) < 1:
371
+ return 'normal_like'
372
+ elif skewness > 1:
373
+ return 'right_skewed'
374
+ elif skewness < -1:
375
+ return 'left_skewed'
376
+ elif kurtosis > 3:
377
+ return 'heavy_tailed'
378
+ elif kurtosis < 2:
379
+ return 'light_tailed'
380
+ else:
381
+ return 'unknown'
382
+
383
+ def _find_issues(self, data: pd.DataFrame, rules: Dict) -> Dict:
384
+ """Find data problems"""
385
+ issues = {
386
+ 'critical': [],
387
+ 'warning': [],
388
+ 'info': []
389
+ }
390
+
391
+ # 1. Check missing values in important features
392
+ missing_info = data.isnull().sum()
393
+ high_missing_cols = missing_info[missing_info / len(data) * 100 > 20].index.tolist()
394
+
395
+ for col in high_missing_cols:
396
+ missing_pct = missing_info[col] / len(data) * 100
397
+ if missing_pct > 50:
398
+ issues['critical'].append(f"Column '{col}': {missing_pct:.1f}% missing values (critical)")
399
+ elif missing_pct > 20:
400
+ issues['warning'].append(f"Column '{col}': {missing_pct:.1f}% missing values")
401
+
402
+ # 2. Check constant features
403
+ for col in data.columns:
404
+ if data[col].nunique() <= 1:
405
+ issues['critical'].append(f"Column '{col}': constant value")
406
+
407
+ # 3. Check correlation of lag/diff features with their base columns
408
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
409
+ for col in numeric_cols:
410
+ if '_lag_' in col or '_diff_' in col:
411
+ base_col = col.split('_lag_')[0] if '_lag_' in col else col.split('_diff_')[0]
412
+ if base_col in numeric_cols:
413
+ correlation = data[col].corr(data[base_col])
414
+ if pd.notna(correlation) and abs(correlation) > 0.95:
415
+ issues['info'].append(f"Column '{col}': high correlation with '{base_col}' ({correlation:.3f})")
416
+
417
+ # 4. Check time gaps
418
+ if isinstance(data.index, pd.DatetimeIndex):
419
+ time_diff = pd.Series(data.index).diff().dropna()
420
+ if len(time_diff) > 0:
421
+ max_gap = time_diff.max()
422
+ if hasattr(max_gap, 'days') and max_gap.days > 30:
423
+ issues['warning'].append(f"Detected time gap: {max_gap.days} days")
424
+
425
+ # 5. Check target variable
426
+ target = self.config.target_column
427
+ if target in data.columns:
428
+ target_data = data[target].dropna()
429
+ if len(target_data) > 0:
430
+ if target_data.nunique() <= 1:
431
+ issues['critical'].append(f"Target variable '{target}': constant value")
432
+
433
+ # Check for outliers
434
+ outlier_check = self._check_target_outliers(target_data)
435
+ if outlier_check.get('has_outliers', False) and outlier_check.get('outlier_percentage', 0) > 10:
436
+ issues['warning'].append(f"Target variable '{target}': {outlier_check['outlier_percentage']:.1f}% outliers")
437
+
438
+ # 6. Check multicollinearity (simplified)
439
+ if len(numeric_cols) > 5:
440
+ corr_matrix = data[numeric_cols].corr().abs()
441
+ high_corr_pairs = []
442
+
443
+ for i in range(len(corr_matrix.columns)):
444
+ for j in range(i+1, len(corr_matrix.columns)):
445
+ if corr_matrix.iloc[i, j] > 0.9:
446
+ col1 = corr_matrix.columns[i]
447
+ col2 = corr_matrix.columns[j]
448
+ high_corr_pairs.append((col1, col2, corr_matrix.iloc[i, j]))
449
+
450
+ if len(high_corr_pairs) > 5:
451
+ issues['warning'].append(f"Detected multicollinearity: {len(high_corr_pairs)} pairs with correlation > 0.9")
452
+
453
+ return issues
454
+
455
+ def _generate_recommendations(
456
+ self,
457
+ basic_checks: Dict,
458
+ quality_metrics: Dict,
459
+ issues: Dict
460
+ ) -> List[str]:
461
+ """Generate data improvement recommendations"""
462
+ recommendations = []
463
+
464
+ # Recommendations based on basic checks
465
+ for check_name, check_info in basic_checks.items():
466
+ if not check_info.get('passed', True):
467
+ if check_name == 'min_rows':
468
+ recommendations.append(f"Increase data volume: current row count ({check_info['value']}) below minimum threshold ({check_info['threshold']})")
469
+ elif check_name == 'has_target':
470
+ recommendations.append(f"Add target variable '{self.config.target_column}' to data")
471
+ elif check_name == 'missing_percentage':
472
+ recommendations.append(f"Handle missing values: {check_info['value']:.1f}% missing exceeds threshold {check_info['threshold']}%")
473
+ elif check_name == 'duplicates':
474
+ recommendations.append(f"Remove duplicates: {check_info['value']:.1f}% duplicate rows")
475
+
476
+ # Recommendations based on issues
477
+ if issues.get('critical'):
478
+ recommendations.append("Resolve critical issues before using data")
479
+
480
+ if issues.get('warning'):
481
+ recommendations.append("Consider addressing warnings to improve data quality")
482
+
483
+ # Recommendations based on quality metrics
484
+ target_metrics = quality_metrics.get('target_quality', {})
485
+ if target_metrics.get('is_constant', False):
486
+ recommendations.append(f"Target variable '{self.config.target_column}' is constant, different target variable needed")
487
+
488
+ if target_metrics.get('has_outliers', {}).get('has_outliers', False):
489
+ outlier_pct = target_metrics['has_outliers'].get('outlier_percentage', 0)
490
+ if outlier_pct > 5:
491
+ recommendations.append(f"Handle outliers in target variable: {outlier_pct:.1f}% outliers")
492
+
493
+ # Time series stability recommendations
494
+ temporal_stability = quality_metrics.get('temporal_stability', {})
495
+ if temporal_stability.get('time_gaps', {}).get('has_irregular_gaps', False):
496
+ recommendations.append("Detected irregular time intervals, consider resampling")
497
+
498
+ return recommendations
499
+
500
+ def _calculate_overall_score(self, validation_results: Dict) -> float:
501
+ """Calculate overall data quality score"""
502
+ score = 100
503
+
504
+ # Penalties for basic checks
505
+ basic_checks = validation_results.get('basic_checks', {})
506
+ for check_name, check_info in basic_checks.items():
507
+ if not check_info.get('passed', True):
508
+ if check_name == 'min_rows':
509
+ score -= 30
510
+ elif check_name == 'has_target':
511
+ score -= 50
512
+ elif check_name == 'missing_percentage':
513
+ missing_pct = check_info.get('value', 0)
514
+ if missing_pct > 50:
515
+ score -= 40
516
+ elif missing_pct > 20:
517
+ score -= 20
518
+ elif missing_pct > 5:
519
+ score -= 10
520
+ elif check_name == 'duplicates':
521
+ duplicate_pct = check_info.get('value', 0)
522
+ if duplicate_pct > 20:
523
+ score -= 30
524
+ elif duplicate_pct > 10:
525
+ score -= 15
526
+ elif duplicate_pct > 5:
527
+ score -= 5
528
+
529
+ # Penalties for issues
530
+ issues = validation_results.get('issues', {})
531
+ if issues.get('critical'):
532
+ score -= len(issues['critical']) * 20
533
+
534
+ if issues.get('warning'):
535
+ score -= len(issues['warning']) * 5
536
+
537
+ # Bonuses for good metrics
538
+ quality_metrics = validation_results.get('quality_metrics', {})
539
+ target_metrics = quality_metrics.get('target_quality', {})
540
+
541
+ if not target_metrics.get('is_constant', True):
542
+ score += 10
543
+
544
+ if target_metrics.get('missing_percentage', 100) < 1:
545
+ score += 5
546
+
547
+ # Limit score to 0-100 range
548
+ return max(0, min(100, score))
549
+
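# Illustrative sketch (not part of this diff): how the penalty/bonus arithmetic
# above plays out for a hypothetical dataset that passes all basic checks.
score = 100
score -= 1 * 20   # one critical issue
score -= 2 * 5    # two warnings
score += 10       # target variable is not constant
score += 5        # under 1% missing values in the target
print(score)      # 85, then clamped into [0, 100] by max(0, min(100, score))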
550
+ def _log_validation_results(self, validation_results: Dict) -> None:
551
+ """Log validation results"""
552
+ stage = validation_results['stage']
553
+ status = validation_results['status']
554
+ score = validation_results['overall_score']
555
+
556
+ logger.info(f"VALIDATION RESULTS ({stage}):")
557
+ logger.info(f" Status: {status}")
558
+ logger.info(f" Overall score: {score}/100")
559
+ logger.info(f" Data shape: {validation_results['data_shape'][0]}x{validation_results['data_shape'][1]}")
560
+
561
+ # Basic checks
562
+ logger.info("\nBASIC CHECKS:")
563
+ for check_name, check_info in validation_results['basic_checks'].items():
564
+ status_icon = "✓" if check_info.get('passed', True) else "✗"
565
+ logger.info(f" {status_icon} {check_name}: {check_info.get('value', 'N/A')}")
566
+
567
+ # Issues
568
+ issues = validation_results['issues']
569
+ if any(issues.values()):
570
+ logger.info("\nDETECTED ISSUES:")
571
+ for severity, issue_list in issues.items():
572
+ if issue_list:
573
+ logger.info(f" {severity.upper()}:")
574
+ for issue in issue_list[:5]: # Show only first 5 issues of each type
575
+ logger.info(f" - {issue}")
576
+ if len(issue_list) > 5:
577
+ logger.info(f" ... and {len(issue_list) - 5} more issues")
578
+ else:
579
+ logger.info("\n✓ No issues detected")
580
+
581
+ # Recommendations
582
+ recommendations = validation_results['recommendations']
583
+ if recommendations:
584
+ logger.info("\nRECOMMENDATIONS:")
585
+ for i, rec in enumerate(recommendations, 1):
586
+ logger.info(f" {i}. {rec}")
587
+
588
+ # Conclusion
589
+ if status == 'PASS':
590
+ logger.info("\n✓ Data passed validation and is ready for use")
591
+ elif status == 'WARNING':
592
+ logger.info("\n⚠ Data requires attention, there are issues to address")
593
+ else:
594
+ logger.info("\n✗ Data requires significant improvement before use")
595
+
596
+ def generate_report(self, stage: str = 'final') -> Dict:
597
+ """Generate detailed validation report"""
598
+ if stage not in self.validation_results:
599
+ return {}
600
+
601
+ report = self.validation_results[stage].copy()
602
+
603
+ # Add metadata
604
+ report['config'] = self.config.to_dict()
605
+ report['validator_version'] = '1.0'
606
+ report['generation_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
607
+
608
+ # Add detailed metrics
609
+ quality_metrics = report.get('quality_metrics', {})
610
+
611
+ if 'numeric_statistics' in quality_metrics:
612
+ # Numeric features summary
613
+ numeric_stats = quality_metrics['numeric_statistics']
614
+ report['numeric_summary'] = {
615
+ 'total_numeric_features': len(numeric_stats),
616
+ 'features_with_high_skewness': sum(1 for s in numeric_stats.values() if abs(s.get('skewness', 0)) > 1),
617
+ 'features_with_high_kurtosis': sum(1 for s in numeric_stats.values() if abs(s.get('kurtosis', 0)) > 3),
618
+ 'features_with_many_zeros': sum(1 for s in numeric_stats.values() if s.get('zeros_percentage', 0) > 50)
619
+ }
620
+
621
+ return report
622
+
623
+ def save_report(self, stage: str = 'final', path: str = None) -> None:
624
+ """Save validation report to file"""
625
+ if stage not in self.validation_results:
626
+ logger.warning(f"Report for stage '{stage}' not found")
627
+ return
628
+
629
+ report = self.generate_report(stage)
630
+
631
+ if path is None:
632
+ path = f'{self.config.results_dir}/reports/validation_report_{stage}.json'
633
+
634
+ # Create directory if needed
635
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
636
+
637
+ # Custom JSON encoder
638
+ class NumpyEncoder(json.JSONEncoder):
639
+ def default(self, obj):
640
+ if isinstance(obj, np.integer):
641
+ return int(obj)  # keep integer values as integers in JSON
642
+ elif isinstance(obj, np.floating):
643
+ return None if np.isnan(obj) else float(obj)
644
+ elif isinstance(obj, np.bool_):
645
+ return bool(obj)
646
+ elif isinstance(obj, np.ndarray):
647
+ return obj.tolist()
648
+ elif isinstance(obj, pd.Timestamp):
649
+ return obj.strftime('%Y-%m-%d %H:%M:%S')
650
+ return super().default(obj)
651
+
652
+ with open(path, 'w', encoding='utf-8') as f:
653
+ json.dump(report, f, indent=4, ensure_ascii=False, cls=NumpyEncoder)
654
+
655
+ logger.info(f"✓ Validation report saved: {path}")
visualization/__init__.py ADDED
File without changes
visualization/visualization_manager.py ADDED
@@ -0,0 +1,1462 @@
1
+ # ============================================
2
+ # CLASS 13: VISUALISATION MANAGER (UPDATED)
3
+ # ============================================
4
+ import os
5
+ from datetime import datetime
6
+ import json
7
+ from typing import Dict, List, Optional, Tuple, Union, Any
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+ from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
12
+ import matplotlib
13
+ matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from scipy.stats import gaussian_kde
17
+
18
+ from config.config import Config
19
+ import logging
20
+
21
+ # Logging setup
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class VisualisationManager:
27
+ """Class for managing all visualisations"""
28
+
29
+ def __init__(self, config: Config):
30
+ """
31
+ Initialise visualisation manager
32
+
33
+ Parameters:
34
+ -----------
35
+ config : Config
36
+ Experiment configuration
37
+ """
38
+ self.config = config
39
+ self.plots_generated = {}
40
+ self.plot_files = {}
41
+ self.figure_count = 0
42
+
43
+ # Create directory structure for saving plots
44
+ self._create_directory_structure()
45
+
46
+ def _create_directory_structure(self) -> None:
47
+ """Create directory structure for saving plots"""
48
+ base_dir = self.config.results_dir
49
+
50
+ # Main plot directories
51
+ self.plots_dir = os.path.join(base_dir, "plots")
52
+ self.correlations_dir = os.path.join(base_dir, "plots", "correlations")
53
+ self.distributions_dir = os.path.join(base_dir, "plots", "distributions")
54
+ self.features_dir = os.path.join(base_dir, "plots", "features")
55
+ self.time_series_dir = os.path.join(base_dir, "plots", "time_series")
56
+ self.preprocessing_dir = os.path.join(base_dir, "plots", "preprocessing")
57
+ self.summary_dir = os.path.join(base_dir, "plots", "summary")
58
+ self.reports_dir = os.path.join(base_dir, "reports")
59
+
60
+ # Create directories
61
+ directories = [
62
+ self.plots_dir,
63
+ self.correlations_dir,
64
+ self.distributions_dir,
65
+ self.features_dir,
66
+ self.time_series_dir,
67
+ self.preprocessing_dir,
68
+ self.summary_dir,
69
+ self.reports_dir
70
+ ]
71
+
72
+ for directory in directories:
73
+ os.makedirs(directory, exist_ok=True)
74
+ logger.debug(f"Created directory: {directory}")
75
+
76
+ def _save_figure(self, fig: plt.Figure, filename: str,
77
+ subdirectory: str = None, dpi: int = 300) -> str:
78
+ """
79
+ Save plot and close it
80
+
81
+ Parameters:
82
+ -----------
83
+ fig : matplotlib.figure.Figure
84
+ Plot figure object
85
+ filename : str
86
+ Filename for saving
87
+ subdirectory : str, optional
88
+ Subdirectory for saving
89
+ dpi : int
90
+ Save quality
91
+
92
+ Returns:
93
+ --------
94
+ str : full path to saved file
95
+ """
96
+ if not filename.endswith('.png'):
97
+ filename = f"{filename}.png"
98
+
99
+ if subdirectory:
100
+ save_dir = os.path.join(self.plots_dir, subdirectory)
101
+ os.makedirs(save_dir, exist_ok=True)
102
+ else:
103
+ save_dir = self.plots_dir
104
+
105
+ filepath = os.path.join(save_dir, filename)
106
+
107
+ try:
108
+ fig.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white')
109
+ logger.info(f"✓ Plot saved: {filepath}")
110
+ except Exception as e:
111
+ logger.error(f"✗ Error saving plot {filename}: {e}")
112
+ filepath = None
113
+
114
+ # Close plot without display
115
+ plt.close(fig)
116
+
117
+ return filepath
118
+
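# Illustrative sketch (not part of this diff): the save-and-close pattern that
# _save_figure wraps, with the Agg backend selected so no window is opened.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2, 1, 3])
fig.savefig('example.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)  # explicit close prevents figure accumulation in long batch runs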
119
+ # ============================================
120
+ # MAIN VISUALISATION METHODS
121
+ # ============================================
122
+
123
+ def create_summary_dashboard(
124
+ self,
125
+ data: pd.DataFrame,
126
+ preprocessing_stages: Dict = None,
127
+ filename: str = "summary_dashboard"
128
+ ) -> str:
129
+ """
130
+ Create summary visualisation dashboard
131
+
132
+ Parameters:
133
+ -----------
134
+ data : pd.DataFrame
135
+ Data for visualisation
136
+ preprocessing_stages : Dict, optional
137
+ Preprocessing stages information
138
+ filename : str
139
+ Filename for saving
140
+
141
+ Returns:
142
+ --------
143
+ str : path to saved file or None if error
144
+ """
145
+ logger.info("\n" + "="*80)
146
+ logger.info("CREATING SUMMARY DASHBOARD")
147
+ logger.info("="*80)
148
+
149
+ target_col = self.config.target_column
150
+
151
+ try:
152
+ # Create large dashboard
153
+ fig = plt.figure(figsize=(20, 24))
154
+ gs = fig.add_gridspec(6, 4, hspace=0.3, wspace=0.3)
155
+
156
+ # 1. Time series of target variable
157
+ ax1 = fig.add_subplot(gs[0, :2])
158
+ if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
159
+ ax1.plot(data.index, data[target_col], linewidth=1, color='blue', alpha=0.7)
160
+ ax1.set_title(f'Time Series: {target_col}', fontsize=12, fontweight='bold')
161
+ ax1.set_xlabel('Date', fontsize=10)
162
+ ax1.set_ylabel(target_col, fontsize=10)
163
+ ax1.grid(True, alpha=0.3)
164
+ ax1.tick_params(axis='x', rotation=45)
165
+ else:
166
+ ax1.text(0.5, 0.5, 'No time series data available',
167
+ ha='center', va='center', transform=ax1.transAxes)
168
+
169
+ # 2. Target variable distribution
170
+ ax2 = fig.add_subplot(gs[0, 2:])
171
+ if target_col in data.columns:
172
+ values = data[target_col].dropna()
173
+ if len(values) > 0:
174
+ ax2.hist(values, bins=30, edgecolor='black', alpha=0.7, color='green')
175
+ ax2.set_title(f'Distribution: {target_col}', fontsize=12, fontweight='bold')
176
+ ax2.set_xlabel(target_col, fontsize=10)
177
+ ax2.set_ylabel('Frequency', fontsize=10)
178
+ ax2.grid(True, alpha=0.3)
179
+ else:
180
+ ax2.text(0.5, 0.5, 'No data for distribution',
181
+ ha='center', va='center', transform=ax2.transAxes)
182
+
183
+ # 3. Correlation matrix (top features)
184
+ ax3 = fig.add_subplot(gs[1, :])
185
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
186
+ if len(numeric_cols) > 1:
187
+ display_cols = list(numeric_cols[:15])
188
+ if target_col not in display_cols and target_col in data.columns:
189
+ display_cols = [target_col] + [c for c in display_cols if c != target_col][:14]
190
+
191
+ corr_matrix = data[display_cols].corr()
192
+ mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
193
+
194
+ im = ax3.imshow(corr_matrix.where(~mask), cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
195
+ ax3.set_title('Correlation Matrix (Top 15 Features)',
196
+ fontsize=12, fontweight='bold')
197
+ ax3.set_xticks(range(len(display_cols)))
198
+ ax3.set_yticks(range(len(display_cols)))
199
+ ax3.set_xticklabels(display_cols, rotation=90, fontsize=8)
200
+ ax3.set_yticklabels(display_cols, fontsize=8)
201
+ plt.colorbar(im, ax=ax3, shrink=0.8)
202
+
203
+ # 4. Seasonal patterns
204
+ ax4 = fig.add_subplot(gs[2, :2])
205
+ if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
206
+ data_copy = data.copy()
207
+ data_copy['month'] = data_copy.index.month
208
+
209
+ monthly_avg = data_copy.groupby('month')[target_col].mean()
210
+ colors = plt.cm.Set3(np.linspace(0, 1, len(monthly_avg)))
211
+ ax4.bar(monthly_avg.index, monthly_avg.values, color=colors, edgecolor='black')
212
+ ax4.set_title('Average Values by Month', fontsize=12, fontweight='bold')
213
+ ax4.set_xlabel('Month', fontsize=10)
214
+ ax4.set_ylabel(f'Average {target_col}', fontsize=10)
215
+ ax4.set_xticks(range(1, 13))
216
+ month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
217
+ 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
218
+ ax4.set_xticklabels(month_names)
219
+ ax4.grid(True, alpha=0.3, axis='y')
220
+
221
+ # 5. Weekly patterns
222
+ ax5 = fig.add_subplot(gs[2, 2:])
223
+ if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
224
+ data_copy = data.copy()
225
+ data_copy['dayofweek'] = data_copy.index.dayofweek
226
+
227
+ daily_avg = data_copy.groupby('dayofweek')[target_col].mean()
228
+ colors = plt.cm.Paired(np.linspace(0, 1, len(daily_avg)))
229
+ ax5.bar(daily_avg.index, daily_avg.values, color=colors, edgecolor='black')
230
+ ax5.set_title('Average Values by Day of Week', fontsize=12, fontweight='bold')
231
+ ax5.set_xlabel('Day of Week', fontsize=10)
232
+ ax5.set_ylabel(f'Average {target_col}', fontsize=10)
233
+ ax5.set_xticks(range(7))
234
+ ax5.set_xticklabels(['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'])
235
+ ax5.grid(True, alpha=0.3, axis='y')
236
+
237
+ # 6. Trend and seasonality
238
+ ax6 = fig.add_subplot(gs[3, :])
239
+ if target_col in data.columns and len(data) > 30:
240
+ try:
241
+ window_size = min(365, len(data) // 10)
242
+ if window_size >= 7:
243
+ rolling_mean = data[target_col].rolling(window=window_size, center=True).mean()
244
+ rolling_std = data[target_col].rolling(window=window_size, center=True).std()
245
+
246
+ ax6.plot(data.index, data[target_col], alpha=0.5,
247
+ label='Original Series', linewidth=0.5, color='blue')
248
+ ax6.plot(rolling_mean.index, rolling_mean,
249
+ label=f'Rolling Mean ({window_size} days)',
250
+ color='red', linewidth=2)
251
+ ax6.fill_between(rolling_mean.index,
252
+ rolling_mean - rolling_std,
253
+ rolling_mean + rolling_std,
254
+ alpha=0.2, color='red')
255
+
256
+ ax6.set_title('Trend and Volatility', fontsize=12, fontweight='bold')
257
+ ax6.set_xlabel('Date', fontsize=10)
258
+ ax6.set_ylabel(target_col, fontsize=10)
259
+ ax6.legend(fontsize=9, loc='upper left')
260
+ ax6.grid(True, alpha=0.3)
261
+ else:
262
+ ax6.text(0.5, 0.5, 'Insufficient data for trend analysis',
263
+ ha='center', va='center', transform=ax6.transAxes)
264
+ except Exception as e:
265
+ logger.warning(f"Error plotting trend: {e}")
266
+ ax6.text(0.5, 0.5, 'Error plotting trend',
267
+ ha='center', va='center', transform=ax6.transAxes)
268
+
269
+ # 7. Preprocessing statistics
270
+ if preprocessing_stages:
271
+ ax7 = fig.add_subplot(gs[4, :2])
272
+
273
+ stages = list(preprocessing_stages.keys())
274
+ values = list(preprocessing_stages.values())
275
+
276
+ colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(stages)))
277
+ bars = ax7.bar(range(len(stages)), values, color=colors, edgecolor='black')
278
+ ax7.set_title('Preprocessing Statistics', fontsize=12, fontweight='bold')
279
+ ax7.set_xlabel('Processing Stage', fontsize=10)
280
+ ax7.set_ylabel('Value', fontsize=10)
281
+ ax7.set_xticks(range(len(stages)))
282
+ ax7.set_xticklabels([s[:15] + '...' if len(s) > 15 else s for s in stages],
283
+ rotation=45, ha='right', fontsize=9)
284
+ ax7.grid(True, alpha=0.3, axis='y')
285
+
286
+ # Add values on bars
287
+ for bar, value in zip(bars, values):
288
+ height = bar.get_height()
289
+ ax7.text(bar.get_x() + bar.get_width()/2., height,
290
+ f'{value:.2f}', ha='center', va='bottom', fontsize=8)
291
+
292
+ # 8. Data information
293
+ ax8 = fig.add_subplot(gs[4, 2:])
294
+ ax8.axis('off')
295
+
296
+ info_text = []
297
+ info_text.append("GENERAL CHARACTERISTICS:")
298
+ info_text.append(f"• Number of records: {len(data):,}")
299
+ info_text.append(f"• Number of features: {len(data.columns)}")
300
+
301
+ if isinstance(data.index, pd.DatetimeIndex):
302
+ info_text.append(f"• Period: {data.index.min().strftime('%Y-%m-%d')} - "
303
+ f"{data.index.max().strftime('%Y-%m-%d')}")
304
+ info_text.append(f"• Days of data: {(data.index.max() - data.index.min()).days}")
305
+
306
+ if target_col in data.columns:
307
+ target_stats = data[target_col].describe()
308
+ info_text.append(f"\nTARGET VARIABLE '{target_col}':")
309
+ info_text.append(f"• Mean: {target_stats['mean']:.2f}")
310
+ info_text.append(f"• Standard deviation: {target_stats['std']:.2f}")
311
+ info_text.append(f"• Minimum: {target_stats['min']:.2f}")
312
+ info_text.append(f"• 25%: {target_stats['25%']:.2f}")
313
+ info_text.append(f"• 50% (median): {target_stats['50%']:.2f}")
314
+ info_text.append(f"• 75%: {target_stats['75%']:.2f}")
315
+ info_text.append(f"• Maximum: {target_stats['max']:.2f}")
316
+
317
+ info_text.append(f"\nDATA TYPES:")
318
+ for dtype, count in data.dtypes.value_counts().items():
319
+ info_text.append(f"• {dtype}: {count} columns")
320
+
321
+ missing_info = data.isnull().sum()
322
+ missing_total = missing_info.sum()
323
+ missing_percent = missing_total / data.size * 100
324
+ info_text.append(f"\nMISSING VALUES:")
325
+ info_text.append(f"• Total missing: {missing_total:,}")
326
+ info_text.append(f"• Missing percentage: {missing_percent:.2f}%")
327
+
328
+ if missing_total > 0:
329
+ top_missing = missing_info.nlargest(5)
330
+ info_text.append(f"• Top 5 columns with missing values:")
331
+ for col, count in top_missing.items():
332
+ percent = count / len(data) * 100
333
+ info_text.append(f" {col}: {count} ({percent:.1f}%)")
334
+
335
+ ax8.text(0.02, 0.98, '\n'.join(info_text), transform=ax8.transAxes,
336
+ fontsize=8, verticalalignment='top', fontfamily='monospace')
337
+
338
+ # 9. Autocorrelation plot
339
+ ax9 = fig.add_subplot(gs[5, :2])
340
+ if target_col in data.columns:
341
+ try:
342
+ series = data[target_col].dropna()
343
+ if len(series) > 50:
344
+ plot_acf(series, lags=min(50, len(series)-1), ax=ax9, alpha=0.05)
345
+ ax9.set_title('Autocorrelation Function (ACF)', fontsize=12, fontweight='bold')
346
+ ax9.set_xlabel('Lag', fontsize=10)
347
+ ax9.set_ylabel('Autocorrelation', fontsize=10)
348
+ ax9.grid(True, alpha=0.3)
349
+ else:
350
+ ax9.text(0.5, 0.5, 'Insufficient data for ACF',
351
+ ha='center', va='center', transform=ax9.transAxes)
352
+ except Exception as e:
353
+ logger.warning(f"Error plotting ACF: {e}")
354
+ ax9.text(0.5, 0.5, 'Error calculating ACF',
355
+ ha='center', va='center', transform=ax9.transAxes)
356
+
357
+ # 10. Partial autocorrelation plot
358
+ ax10 = fig.add_subplot(gs[5, 2:])
359
+ if target_col in data.columns:
360
+ try:
361
+ series = data[target_col].dropna()
362
+ if len(series) > 50:
363
+ plot_pacf(series, lags=min(50, len(series)-1), ax=ax10, alpha=0.05)
364
+ ax10.set_title('Partial Autocorrelation Function (PACF)',
365
+ fontsize=12, fontweight='bold')
366
+ ax10.set_xlabel('Lag', fontsize=10)
367
+ ax10.set_ylabel('Partial Autocorrelation', fontsize=10)
368
+ ax10.grid(True, alpha=0.3)
369
+ else:
370
+ ax10.text(0.5, 0.5, 'Insufficient data for PACF',
371
+ ha='center', va='center', transform=ax10.transAxes)
372
+ except Exception as e:
373
+ logger.warning(f"Error plotting PACF: {e}")
374
+ ax10.text(0.5, 0.5, 'Error calculating PACF',
375
+ ha='center', va='center', transform=ax10.transAxes)
376
+
377
+ plt.suptitle('Data Analysis Summary Dashboard', fontsize=16, fontweight='bold', y=0.98)
378
+ plt.tight_layout()
379
+
380
+ # Save
381
+ filepath = self._save_figure(fig, filename, "summary")
382
+ self.plot_files['summary_dashboard'] = filepath
383
+ return filepath
384
+
385
+ except Exception as e:
386
+ logger.error(f"Error creating summary dashboard: {e}")
387
+ return None
388
+
389
+ # ============================================
390
+ # SPECIFIC METHODS FOR SAVING PIPELINE-GENERATED PLOTS
391
+ # ============================================
392
+
393
+ def save_data_split_plot(self, filename: str = "data_split.png") -> str:
394
+ """
395
+ Save data split plot
396
+
397
+ Parameters:
398
+ -----------
399
+ filename : str
400
+ Filename for saving
401
+
402
+ Returns:
403
+ --------
404
+ str : path to saved file
405
+ """
406
+ try:
407
+ fig = plt.gcf() # Get current figure
408
+ filepath = self._save_figure(fig, filename, "time_series")
409
+ self.plot_files['data_split'] = filepath
410
+ return filepath
411
+ except Exception as e:
412
+ logger.error(f"Error saving data_split plot: {e}")
413
+ return None
414
+
415
+ def save_feature_selection_correlation_plot(self, filename: str = "feature_selection_correlation.png") -> str:
416
+ """
417
+ Save feature selection correlation plot
418
+
419
+ Parameters:
420
+ -----------
421
+ filename : str
422
+ Filename for saving
423
+
424
+ Returns:
425
+ --------
426
+ str : path to saved file
427
+ """
428
+ try:
429
+ fig = plt.gcf() # Get current figure
430
+ filepath = self._save_figure(fig, filename, "correlations")
431
+ self.plot_files['feature_selection_correlation'] = filepath
432
+ return filepath
433
+ except Exception as e:
434
+ logger.error(f"Error saving feature_selection_correlation plot: {e}")
435
+ return None
436
+
437
+ def save_missing_values_analysis_plot(self, filename: str = "missing_values_analysis.png") -> str:
438
+ """
439
+ Save missing values analysis plot
440
+
441
+ Parameters:
442
+ -----------
443
+ filename : str
444
+ Filename for saving
445
+
446
+ Returns:
447
+ --------
448
+ str : path to saved file
449
+ """
450
+ try:
451
+ fig = plt.gcf() # Get current figure
452
+ filepath = self._save_figure(fig, filename, "preprocessing")
453
+ self.plot_files['missing_values_analysis'] = filepath
454
+ return filepath
455
+ except Exception as e:
456
+ logger.error(f"Error saving missing_values_analysis plot: {e}")
457
+ return None
458
+
459
+ def save_outlier_handling_results_plot(self, filename: str = "outlier_handling_results.png") -> str:
460
+ """
461
+ Save outlier handling results plot
462
+
463
+ Parameters:
464
+ -----------
465
+ filename : str
466
+ Filename for saving
467
+
468
+ Returns:
469
+ --------
470
+ str : path to saved file
471
+ """
472
+ try:
473
+ fig = plt.gcf() # Get current figure
474
+ filepath = self._save_figure(fig, filename, "preprocessing")
475
+ self.plot_files['outlier_handling_results'] = filepath
476
+ return filepath
477
+ except Exception as e:
478
+ logger.error(f"Error saving outlier_handling_results plot: {e}")
479
+ return None
480
+
481
+ def save_outliers_analysis_plot(self, filename: str = "outliers_analysis.png") -> str:
482
+ """
483
+ Save outliers analysis plot
484
+
485
+ Parameters:
486
+ -----------
487
+ filename : str
488
+ Filename for saving
489
+
490
+ Returns:
491
+ --------
492
+ str : path to saved file
493
+ """
494
+ try:
495
+ fig = plt.gcf() # Get current figure
496
+ filepath = self._save_figure(fig, filename, "preprocessing")
497
+ self.plot_files['outliers_analysis'] = filepath
498
+ return filepath
499
+ except Exception as e:
500
+ logger.error(f"Error saving outliers_analysis plot: {e}")
501
+ return None
502
+
503
+ def save_scaling_results_plot(self, filename: str = "scaling_results.png") -> str:
504
+ """
505
+ Save scaling results plot
506
+
507
+ Parameters:
508
+ -----------
509
+ filename : str
510
+ Filename for saving
511
+
512
+ Returns:
513
+ --------
514
+ str : path to saved file
515
+ """
516
+ try:
517
+ fig = plt.gcf() # Get current figure
518
+ filepath = self._save_figure(fig, filename, "preprocessing")
519
+ self.plot_files['scaling_results'] = filepath
520
+ return filepath
521
+ except Exception as e:
522
+ logger.error(f"Error saving scaling_results plot: {e}")
523
+ return None
524
+
525
+ def save_stationarity_analysis_plot(self, filename: str = "stationarity_analysis.png") -> str:
526
+ """
527
+ Save stationarity analysis plot
528
+
529
+ Parameters:
530
+ -----------
531
+ filename : str
532
+ Filename for saving
533
+
534
+ Returns:
535
+ --------
536
+ str : path to saved file
537
+ """
538
+ try:
539
+ fig = plt.gcf() # Get current figure
540
+ filepath = self._save_figure(fig, filename, "time_series")
541
+ self.plot_files['stationarity_analysis'] = filepath
542
+ return filepath
543
+ except Exception as e:
544
+ logger.error(f"Error saving stationarity_analysis plot: {e}")
545
+ return None
546
+
547
+ def save_temporal_outliers_plot(self, filename: str = "temporal_outliers.png") -> str:
548
+ """
549
+ Save temporal outliers plot
550
+
551
+ Parameters:
552
+ -----------
553
+ filename : str
554
+ Filename for saving
555
+
556
+ Returns:
557
+ --------
558
+ str : path to saved file
559
+ """
560
+ try:
561
+ fig = plt.gcf() # Get current figure
562
+ filepath = self._save_figure(fig, filename, "time_series")
563
+ self.plot_files['temporal_outliers'] = filepath
564
+ return filepath
565
+ except Exception as e:
566
+ logger.error(f"Error saving temporal_outliers plot: {e}")
567
+ return None
568
+
569
+ # ============================================
570
+ # UNIVERSAL METHOD FOR SAVING ANY PLOT
571
+ # ============================================
572
+
573
+ def save_current_plot(self, filename: str, subdirectory: str = None) -> str:
574
+ """
575
+ Universal method for saving current plot
576
+
577
+ Parameters:
578
+ -----------
579
+ filename : str
580
+ Filename for saving
581
+ subdirectory : str, optional
582
+ Subdirectory for saving
583
+
584
+ Returns:
585
+ --------
586
+ str : path to saved file
587
+ """
588
+ try:
589
+ fig = plt.gcf() # Get current figure
590
+ filepath = self._save_figure(fig, filename, subdirectory)
591
+
592
+ # Save plot information
593
+ plot_key = filename.replace('.png', '').replace('.jpg', '')
594
+ self.plot_files[plot_key] = filepath
595
+
596
+ return filepath
597
+ except Exception as e:
598
+ logger.error(f"Error saving plot {filename}: {e}")
599
+ return None
600
+
601
+ # ============================================
602
+ # ADDITIONAL VISUALISATION METHODS
603
+ # ============================================
604
+
605
+ def create_feature_importance_plot(
606
+ self,
607
+ feature_importance: Dict,
608
+ top_n: int = 20,
609
+ filename: str = "feature_importance"
610
+ ) -> str:
611
+ """
612
+ Create feature importance plot
613
+
614
+ Parameters:
615
+ -----------
616
+ feature_importance : Dict
617
+ Dictionary with feature importance
618
+ top_n : int
619
+ Number of top features to display
620
+ filename : str
621
+ Filename for saving
622
+
623
+ Returns:
624
+ --------
625
+ str : path to saved file or None if error
626
+ """
627
+ if not feature_importance:
628
+ logger.warning("No feature importance data for visualisation")
629
+ return None
630
+
631
+ try:
632
+ # Convert to Series and sort
633
+ importance_series = pd.Series(feature_importance).sort_values(ascending=False)
634
+ top_features = importance_series.head(top_n)
635
+
636
+ # Create plot
637
+ fig, ax = plt.subplots(figsize=(12, 8))
638
+
639
+ y_pos = np.arange(len(top_features))
640
+ colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(top_features)))
641
+
642
+ bars = ax.barh(y_pos, top_features.values, color=colors, edgecolor='black')
643
+ ax.set_yticks(y_pos)
644
+ ax.set_yticklabels(top_features.index, fontsize=10)
645
+ ax.invert_yaxis()
646
+ ax.set_xlabel('Feature Importance', fontsize=11, fontweight='bold')
647
+ ax.set_title(f'Top-{top_n} Most Important Features', fontsize=14, fontweight='bold')
648
+ ax.grid(True, alpha=0.3, axis='x')
649
+
650
+ # Add values on bars
651
+ for i, (bar, value) in enumerate(zip(bars, top_features.values)):
652
+ width = bar.get_width()
653
+ ax.text(width * 1.01, bar.get_y() + bar.get_height()/2,
654
+ f'{value:.4f}', va='center', fontsize=9, fontweight='bold')
655
+
656
+ # Add additional information
657
+ plt.text(0.02, 0.98, f'Total features: {len(importance_series)}',
658
+ transform=fig.transFigure, fontsize=9, verticalalignment='top')
659
+
660
+ plt.tight_layout()
661
+
662
+ # Save
663
+ filepath = self._save_figure(fig, filename, "features")
664
+ self.plot_files['feature_importance'] = filepath
665
+ return filepath
666
+
667
+ except Exception as e:
668
+ logger.error(f"Error creating feature importance plot: {e}")
669
+ return None
670
+
671
+ def create_correlation_heatmap(
672
+ self,
673
+ data: pd.DataFrame,
674
+ top_n: int = 20,
675
+ filename: str = "correlation_heatmap"
676
+ ) -> Tuple[str, Optional[str]]:
677
+ """
678
+ Create correlation heatmap
679
+
680
+ Parameters:
681
+ -----------
682
+ data : pd.DataFrame
683
+ Data for analysis
684
+ top_n : int
685
+ Number of top features to display
686
+ filename : str
687
+ Filename for saving
688
+
689
+ Returns:
690
+ --------
691
+ Tuple[str, Optional[str]]:
692
+ (path to main heatmap, path to target correlation heatmap)
693
+ """
694
+ target_col = self.config.target_column
695
+
696
+ try:
697
+ numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
698
+
699
+ if len(numeric_cols) < 2:
700
+ logger.warning("Insufficient numeric features for correlation analysis")
701
+ return None, None
702
+
703
+ # Create two heatmaps
704
+
705
+ # 1. Main correlation heatmap between all features
706
+ main_filepath = self._create_main_correlation_heatmap(data, numeric_cols, top_n, filename)
707
+
708
+ # 2. Target correlation heatmap
709
+ target_filepath = None
710
+ if target_col in data.columns and target_col in numeric_cols:
711
+ target_filepath = self._create_target_correlation_heatmap(data, target_col, numeric_cols, filename)
712
+
713
+ return main_filepath, target_filepath
714
+
715
+ except Exception as e:
716
+ logger.error(f"Error creating correlation heatmap: {e}")
717
+ return None, None
718
+
719
+ def _create_main_correlation_heatmap(
720
+ self,
721
+ data: pd.DataFrame,
722
+ numeric_cols: List[str],
723
+ top_n: int,
724
+ filename: str
725
+ ) -> str:
726
+ """Create main correlation heatmap"""
727
+ # Limit number of features for better readability
728
+ if len(numeric_cols) > top_n:
729
+ # Select features with highest variance
730
+ variances = data[numeric_cols].var().sort_values(ascending=False)
731
+ selected_cols = variances.head(top_n).index.tolist()
732
+ else:
733
+ selected_cols = numeric_cols
734
+
735
+ # Calculate correlation
736
+ corr_matrix = data[selected_cols].corr()
737
+
738
+ fig, ax = plt.subplots(figsize=(14, 12))
739
+
740
+ # Mask for upper triangle
741
+ mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
742
+
743
+ # Create heatmap
744
+ sns.heatmap(
745
+ corr_matrix,
746
+ annot=True,
747
+ fmt='.2f',
748
+ cmap='coolwarm',
749
+ center=0,
750
+ square=True,
751
+ mask=mask,
752
+ cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'},
753
+ linewidths=0.5,
754
+ linecolor='white',
755
+ ax=ax,
756
+ annot_kws={'size': 8}
757
+ )
758
+
759
+ ax.set_title(f'Correlation Matrix Between Features (Top-{top_n})',
760
+ fontsize=14, fontweight='bold', pad=20)
761
+
762
+ plt.tight_layout()
763
+
764
+ # Save
765
+ filepath = self._save_figure(fig, filename, "correlations")
766
+ self.plot_files['correlation_heatmap_main'] = filepath
767
+ return filepath
768
+
769
+ def _create_target_correlation_heatmap(
770
+ self,
771
+ data: pd.DataFrame,
772
+ target_col: str,
773
+ numeric_cols: List[str],
774
+ filename: str
775
+ ) -> str:
776
+ """Create target correlation heatmap"""
777
+ # Calculate correlations with target variable
778
+ correlations = data[numeric_cols].corrwith(data[target_col]).sort_values(key=abs, ascending=False)
779
+
780
+ # Exclude target variable itself
781
+ correlations = correlations[correlations.index != target_col]
782
+
783
+ # Take top 15 features
784
+ top_features = correlations.head(15)
785
+
786
+ fig, ax = plt.subplots(figsize=(10, 8))
787
+
788
+ colors = ['red' if x < 0 else 'green' for x in top_features.values]
789
+ bars = ax.barh(range(len(top_features)), top_features.values, color=colors, edgecolor='black')
790
+
791
+ ax.set_yticks(range(len(top_features)))
792
+ ax.set_yticklabels(top_features.index, fontsize=10)
793
+ ax.invert_yaxis()
794
+ ax.set_xlabel('Correlation Coefficient', fontsize=11, fontweight='bold')
795
+ ax.set_title(f'Feature Correlations with Target Variable "{target_col}"',
796
+ fontsize=14, fontweight='bold', pad=20)
797
+ ax.grid(True, alpha=0.3, axis='x')
798
+ ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
799
+
800
+ # Add values on bars
801
+ for bar, value in zip(bars, top_features.values):
802
+ width = bar.get_width()
803
+ ax.text(width + (0.01 if width >= 0 else -0.04),
804
+ bar.get_y() + bar.get_height()/2,
805
+ f'{value:.3f}',
806
+ va='center',
807
+ ha='left' if width >= 0 else 'right',
808
+ fontsize=9,
809
+ fontweight='bold',
810
+ color='black')
811
+
812
+ plt.tight_layout()
813
+
814
+ # Save
815
+ target_filename = f"{filename}_with_target"
816
+ filepath = self._save_figure(fig, target_filename, "correlations")
817
+ self.plot_files['correlation_with_target'] = filepath
818
+ return filepath
819
+
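# Illustrative sketch (not part of this diff): corrwith plus sort_values(key=abs),
# the core of the target-correlation plot above, on synthetic data.
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
df = pd.DataFrame({'y': rng.normal(size=200)})
df['x1'] = 0.8 * df['y'] + 0.2 * rng.normal(size=200)   # strongly related
df['x2'] = rng.normal(size=200)                          # unrelated
corr = df[['x1', 'x2']].corrwith(df['y']).sort_values(key=abs, ascending=False)
print(corr)  # x1 ranks first because ordering is by absolute correlation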
820
+ def create_distribution_comparison(
821
+ self,
822
+ original_data: pd.DataFrame,
823
+ processed_data: pd.DataFrame,
824
+ columns: List[str] = None,
825
+ max_columns: int = 12,
826
+ filename: str = "distribution_comparison"
827
+ ) -> str:
828
+ """
829
+ Compare distributions before and after processing
830
+
831
+ Parameters:
832
+ -----------
833
+ original_data : pd.DataFrame
834
+ Original data
835
+ processed_data : pd.DataFrame
836
+ Processed data
837
+ columns : List[str], optional
838
+ List of columns to compare
839
+ max_columns : int
840
+ Maximum number of columns to display
841
+ filename : str
842
+ Filename for saving
843
+
844
+ Returns:
845
+ --------
846
+ str : path to saved file or None if error
847
+ """
848
+ try:
849
+ if columns is None:
850
+ # Select numeric columns common to both datasets
851
+ numeric_cols_original = original_data.select_dtypes(include=[np.number]).columns
852
+ numeric_cols_processed = processed_data.select_dtypes(include=[np.number]).columns
853
+ common_cols = list(set(numeric_cols_original) & set(numeric_cols_processed))
854
+
855
+ # Sort by variance in original data
856
+ variances = original_data[common_cols].var().sort_values(ascending=False)
857
+ columns = variances.head(max_columns).index.tolist()
858
+
859
+ n_cols = min(4, len(columns))
860
+ n_rows = (len(columns) + n_cols - 1) // n_cols
861
+
862
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.5))
863
+ fig.suptitle('Distribution Comparison Before and After Processing',
864
+ fontsize=16, fontweight='bold', y=0.98)
865
+
866
+ # Normalise to a flat ndarray so len() and integer indexing always work below
867
+ axes = np.atleast_1d(axes).ravel()
869
+
870
+ for idx, col in enumerate(columns):
871
+ if idx >= len(axes):
872
+ break
873
+
874
+ ax = axes[idx]
875
+
876
+ if col in original_data.columns and col in processed_data.columns:
877
+ original_values = original_data[col].dropna()
878
+ processed_values = processed_data[col].dropna()
879
+
880
+ if len(original_values) > 0 and len(processed_values) > 0:
881
+ # Use common bins for comparison
882
+ all_values = pd.concat([original_values, processed_values])
883
+ bins = np.histogram_bin_edges(all_values, bins=30)
884
+
885
+ # Histograms
886
+ ax.hist(original_values, bins=bins, alpha=0.5,
887
+ label='Before Processing', density=True, color='blue')
888
+ ax.hist(processed_values, bins=bins, alpha=0.5,
889
+ label='After Processing', density=True, color='orange')
890
+
891
+ # Add KDE
892
+ try:
893
+ if len(original_values) > 10:
894
+ kde_original = gaussian_kde(original_values)
895
+ x_range = np.linspace(original_values.min(), original_values.max(), 100)
896
+ ax.plot(x_range, kde_original(x_range), 'b-', linewidth=1.5, alpha=0.8)
897
+
898
+ if len(processed_values) > 10:
899
+ kde_processed = gaussian_kde(processed_values)
900
+ x_range = np.linspace(processed_values.min(), processed_values.max(), 100)
901
+ ax.plot(x_range, kde_processed(x_range), 'orange', linewidth=1.5, alpha=0.8)
902
+ except:
903
+ pass
904
+
905
+ # Add statistics
906
+ stats_text = []
907
+ if len(original_values) > 0:
908
+ stats_text.append(f"Before: μ={original_values.mean():.2f}, σ={original_values.std():.2f}")
909
+ if len(processed_values) > 0:
910
+ stats_text.append(f"After: μ={processed_values.mean():.2f}, σ={processed_values.std():.2f}")
911
+
912
+ if stats_text:
913
+ ax.text(0.02, 0.98, '\n'.join(stats_text),
914
+ transform=ax.transAxes, fontsize=8,
915
+ verticalalignment='top',
916
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
917
+
918
+ ax.set_title(f'{col}', fontsize=11, fontweight='bold')
919
+ ax.set_xlabel('Value', fontsize=9)
920
+ ax.set_ylabel('Density', fontsize=9)
921
+ ax.legend(fontsize=8)
922
+ ax.grid(True, alpha=0.3)
923
+ else:
924
+ ax.text(0.5, 0.5, 'No data',
925
+ ha='center', va='center', transform=ax.transAxes)
926
+ else:
927
+ ax.text(0.5, 0.5, 'Column not found',
928
+ ha='center', va='center', transform=ax.transAxes)
929
+
930
+ # Hide unused subplots
931
+ for idx in range(len(columns), len(axes)):
932
+ axes[idx].set_visible(False)
933
+
934
+ plt.tight_layout()
935
+
936
+ # Save
937
+ filepath = self._save_figure(fig, filename, "distributions")
938
+ self.plot_files['distribution_comparison'] = filepath
939
+ return filepath
940
+
941
+ except Exception as e:
942
+ logger.error(f"Error creating distribution comparison: {e}")
943
+ return None
944
+
945
+ def create_time_series_decomposition_plot(
946
+ self,
947
+ decomposition_result: Dict,
948
+ filename: str = "time_series_decomposition"
949
+ ) -> str:
950
+ """
951
+ Visualise time series decomposition
952
+
953
+ Parameters:
954
+ -----------
955
+ decomposition_result : Dict
956
+ Decomposition results
957
+ filename : str
958
+ Filename for saving
959
+
960
+ Returns:
961
+ --------
962
+ str : path to saved file or None if error
963
+ """
964
+ target_col = self.config.target_column
965
+
966
+ try:
967
+ fig, axes = plt.subplots(4, 1, figsize=(14, 10))
968
+ fig.suptitle(f'Time Series Decomposition: {target_col}',
969
+ fontsize=16, fontweight='bold', y=0.98)
970
+
971
+ # Original series
972
+ if 'observed' in decomposition_result:
973
+ observed = decomposition_result['observed']
974
+ axes[0].plot(observed, color='blue', linewidth=1.5)
975
+ axes[0].set_ylabel('Observed', fontsize=11, fontweight='bold')
976
+ axes[0].grid(True, alpha=0.3)
977
+ axes[0].set_title('Original Time Series', fontsize=12)
978
+
979
+ # Trend
980
+ if 'trend' in decomposition_result and decomposition_result['trend'] is not None:
981
+ trend = decomposition_result['trend']
982
+ axes[1].plot(trend, color='red', linewidth=2)
983
+ axes[1].set_ylabel('Trend', fontsize=11, fontweight='bold')
984
+ axes[1].grid(True, alpha=0.3)
985
+ axes[1].set_title('Trend Component', fontsize=12)
986
+
987
+ # Seasonality
988
+ if 'seasonal' in decomposition_result and decomposition_result['seasonal'] is not None:
989
+ seasonal = decomposition_result['seasonal']
990
+ axes[2].plot(seasonal, color='green', linewidth=1.5)
991
+ axes[2].set_ylabel('Seasonal', fontsize=11, fontweight='bold')
992
+ axes[2].grid(True, alpha=0.3)
993
+ axes[2].set_title('Seasonal Component', fontsize=12)
994
+
995
+ # Residuals
996
+ if 'residual' in decomposition_result and decomposition_result['residual'] is not None:
997
+ residual = decomposition_result['residual']
998
+ axes[3].plot(residual, color='purple', linewidth=1, alpha=0.7)
999
+ axes[3].set_ylabel('Residuals', fontsize=11, fontweight='bold')
1000
+ axes[3].set_xlabel('Date', fontsize=11, fontweight='bold')
1001
+ axes[3].grid(True, alpha=0.3)
1002
+ axes[3].set_title('Residual Component', fontsize=12)
1003
+
1004
+ # Add residual statistics
1005
+ if len(residual) > 0:
1006
+ stats_text = (f"Mean: {residual.mean():.4f}\n"
1007
+ f"Std: {residual.std():.4f}\n"
1008
+ f"Min: {residual.min():.4f}\n"
1009
+ f"Max: {residual.max():.4f}")
1010
+ axes[3].text(0.02, 0.98, stats_text, transform=axes[3].transAxes,
1011
+ fontsize=8, verticalalignment='top',
1012
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
1013
+
1014
+ plt.tight_layout()
1015
+
1016
+ # Save
1017
+ filepath = self._save_figure(fig, filename, "time_series")
1018
+ self.plot_files['time_series_decomposition'] = filepath
1019
+ return filepath
1020
+
1021
+ except Exception as e:
1022
+ logger.error(f"Error creating time series decomposition: {e}")
1023
+ return None
1024
+
1025
+ def create_data_quality_report(
1026
+ self,
1027
+ validation_results: Dict,
1028
+ filename: str = "data_quality_report"
1029
+ ) -> str:
1030
+ """
1031
+ Create visual data quality report
1032
+
1033
+ Parameters:
1034
+ -----------
1035
+ validation_results : Dict
1036
+ Validation results
1037
+ filename : str
1038
+ Filename for saving
1039
+
1040
+ Returns:
1041
+ --------
1042
+ str : path to saved file or None if error
1043
+ """
1044
+ try:
1045
+ fig = plt.figure(figsize=(16, 12))
1046
+ fig.suptitle('Data Quality Report', fontsize=18, fontweight='bold', y=0.98)
1047
+
1048
+ # Use GridSpec for more complex layout
1049
+ gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
1050
+
1051
+ # 1. Quality radar chart (top left)
1052
+ ax1 = fig.add_subplot(gs[0, 0], projection='polar')
1053
+
1054
+ categories = ['Size', 'Missing', 'Duplicates', 'Stability', 'Informativeness']
1055
+
1056
+ # Extract values from validation results
1057
+ if 'quality_metrics' in validation_results:
1058
+ values = [
1059
+ validation_results['quality_metrics'].get('size_score', 0.5),
1060
+ validation_results['quality_metrics'].get('missing_score', 0.5),
1061
+ validation_results['quality_metrics'].get('duplicates_score', 0.5),
1062
+ validation_results['quality_metrics'].get('stability_score', 0.5),
1063
+ validation_results['quality_metrics'].get('informativeness_score', 0.5)
1064
+ ]
1065
+ else:
1066
+ values = [0.8, 0.7, 0.9, 0.6, 0.8]
1067
+
1068
+ N = len(categories)
1069
+ angles = [n / float(N) * 2 * np.pi for n in range(N)]
1070
+ angles += angles[:1]
1071
+ values += values[:1]
1072
+
1073
+ ax1.plot(angles, values, 'o-', linewidth=2, color='blue')
1074
+ ax1.fill(angles, values, alpha=0.25, color='blue')
1075
+ ax1.set_xticks(angles[:-1])
1076
+ ax1.set_xticklabels(categories, fontsize=10)
1077
+ ax1.set_ylim(0, 1)
1078
+ ax1.set_title('Data Quality Radar Chart', fontsize=12, fontweight='bold')
1079
+ ax1.grid(True)
1080
+
1081
+ # 2. Check status (top centre)
1082
+ ax2 = fig.add_subplot(gs[0, 1])
1083
+
1084
+ basic_checks = validation_results.get('basic_checks', {})
1085
+ checks_passed = sum(1 for check in basic_checks.values() if check.get('passed', False))
1086
+ checks_total = len(basic_checks)
1087
+ checks_failed = checks_total - checks_passed
1088
+
1089
+ if checks_total > 0:
1090
+ colors = ['#4CAF50' if checks_passed > 0 else '#FF6B6B',
1091
+ '#FF6B6B' if checks_failed > 0 else '#4CAF50']
1092
+ bars = ax2.bar(['Passed', 'Failed'],
1093
+ [checks_passed, checks_failed],
1094
+ color=colors, edgecolor='black')
1095
+
1096
+ ax2.set_title(f'Basic Checks: {checks_passed}/{checks_total}',
1097
+ fontsize=12, fontweight='bold')
1098
+ ax2.set_ylabel('Number of Checks', fontsize=10)
1099
+ ax2.grid(True, alpha=0.3, axis='y')
1100
+
1101
+ # Add values on bars
1102
+ for bar, value in zip(bars, [checks_passed, checks_failed]):
1103
+ height = bar.get_height()
1104
+ ax2.text(bar.get_x() + bar.get_width()/2., height,
1105
+ f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
1106
+ else:
1107
+ ax2.text(0.5, 0.5, 'No check data available',
1108
+ ha='center', va='center', transform=ax2.transAxes)
1109
+ ax2.set_title('Basic Checks', fontsize=12, fontweight='bold')
1110
+
+            # 3. Overall score (top right)
+            ax3 = fig.add_subplot(gs[0, 2])
+
+            overall_score = validation_results.get('overall_score', 0)
+            status = validation_results.get('status', 'UNKNOWN')
+
+            # Score pie chart (score vs. remainder out of 100)
+            sizes = [overall_score, 100 - overall_score]
+
+            if overall_score >= 80:
+                colors = ['#4CAF50', '#E0E0E0']  # Green
+            elif overall_score >= 60:
+                colors = ['#FFC107', '#E0E0E0']  # Yellow
+            else:
+                colors = ['#F44336', '#E0E0E0']  # Red
+
+            wedges, texts, autotexts = ax3.pie(sizes, colors=colors, startangle=90,
+                                               autopct='%1.1f%%', pctdistance=0.85)
+
+            # Status text in the center of the pie
+            status_colors = {'PASS': '#4CAF50', 'WARNING': '#FFC107', 'FAIL': '#F44336'}
+            status_color = status_colors.get(status, '#757575')
+
+            ax3.text(0, 0, f'{overall_score}/100\n{status}',
+                     ha='center', va='center', fontsize=14, fontweight='bold',
+                     color=status_color)
+            ax3.set_title('Overall Quality Score', fontsize=12, fontweight='bold')
+
+            # 4. Issue distribution by type (middle left)
+            ax4 = fig.add_subplot(gs[1, 0])
+
+            issues = validation_results.get('issues', {})
+            issue_counts = {
+                'Critical': len(issues.get('critical', [])),
+                'Warnings': len(issues.get('warning', [])),
+                'Informational': len(issues.get('info', []))
+            }
+
+            if any(issue_counts.values()):
+                colors = ['#F44336', '#FF9800', '#2196F3']
+                bars = ax4.bar(issue_counts.keys(), issue_counts.values(),
+                               color=colors, edgecolor='black')
+
+                ax4.set_title('Data Issues by Type', fontsize=12, fontweight='bold')
+                ax4.set_ylabel('Number of Issues', fontsize=10)
+                ax4.tick_params(axis='x', rotation=45)
+                ax4.grid(True, alpha=0.3, axis='y')
+
+                # Add values on bars
+                for bar, value in zip(bars, issue_counts.values()):
+                    height = bar.get_height()
+                    ax4.text(bar.get_x() + bar.get_width()/2., height,
+                             f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
+            else:
+                ax4.text(0.5, 0.5, 'No issues detected',
+                         ha='center', va='center', transform=ax4.transAxes, fontsize=12)
+                ax4.set_title('Data Issues', fontsize=12, fontweight='bold')
+
+            # 5. Detailed information (remaining cells)
+            ax5 = fig.add_subplot(gs[1:, 1:])
+            ax5.axis('off')
+
+            # Build the text report
+            report_text = []
+            report_text.append("DETAILED REPORT:")
+            report_text.append("=" * 40)
+
+            # Basic information
+            report_text.append("\nBASIC INFORMATION:")
+            report_text.append(f"• Overall score: {overall_score}/100")
+            report_text.append(f"• Status: {status}")
+            report_text.append(f"• Checks passed: {checks_passed}/{checks_total}")
+
+            # Check details
+            if basic_checks:
+                report_text.append("\nCHECK DETAILS:")
+                for check_name, check_result in basic_checks.items():
+                    status_icon = "✓" if check_result.get('passed', False) else "✗"
+                    report_text.append(f"• {status_icon} {check_name}: {check_result.get('message', '')}")
+
+            # Issues
+            if any(issue_counts.values()):
+                report_text.append("\nDETECTED ISSUES:")
+
+                if issue_counts['Critical'] > 0:
+                    report_text.append("\nCRITICAL:")
+                    for issue in issues.get('critical', []):
+                        report_text.append(f"  • {issue}")
+
+                if issue_counts['Warnings'] > 0:
+                    report_text.append("\nWARNINGS:")
+                    for issue in issues.get('warning', []):
+                        report_text.append(f"  • {issue}")
+
+                if issue_counts['Informational'] > 0:
+                    report_text.append("\nINFORMATIONAL:")
+                    for issue in issues.get('info', []):
+                        report_text.append(f"  • {issue}")
+
+            # Recommendations
+            recommendations = validation_results.get('recommendations', [])
+            if recommendations:
+                report_text.append("\nRECOMMENDATIONS:")
+                for i, rec in enumerate(recommendations, 1):
+                    report_text.append(f"{i}. {rec}")
+
+            ax5.text(0.02, 0.98, '\n'.join(report_text), transform=ax5.transAxes,
+                     fontsize=9, verticalalignment='top', fontfamily='monospace',
+                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.1))
+
+            plt.tight_layout()
+
+            # Save
+            filepath = self._save_figure(fig, filename, "reports")
+            self.plot_files['data_quality_report'] = filepath
+            return filepath
+
+        except Exception as e:
+            logger.error(f"Error creating data quality report: {e}")
+            return None
+
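For reference, the dict shape that `create_data_quality_report` consumes, reconstructed from the `.get(...)` calls in the method above. The key names come directly from the code; the concrete values, check names, and messages are illustrative:

```python
# Shape of validation_results as read by create_data_quality_report
validation_results = {
    "overall_score": 72,               # 0-100; drives the pie color
    "status": "WARNING",               # 'PASS' | 'WARNING' | 'FAIL'
    "quality_metrics": {               # radar axes, each in [0, 1]
        "size_score": 0.8,
        "missing_score": 0.6,
        "duplicates_score": 0.9,
        "stability_score": 0.5,
        "informativeness_score": 0.7,
    },
    "basic_checks": {                  # name -> {'passed': bool, 'message': str}
        "datetime_index": {"passed": True, "message": "Index is monotonic"},
        "min_rows": {"passed": False, "message": "Fewer than 1000 rows"},
    },
    "issues": {
        "critical": [],
        "warning": ["3.2% missing values in target column"],
        "info": ["Constant column detected"],
    },
    "recommendations": ["Impute missing target values before modeling"],
}

# visualizer.create_data_quality_report(validation_results)
```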
+    # ============================================
+    # METHODS FOR BATCH SAVING
+    # ============================================
+
+    def save_all_preprocessing_plots(self) -> Dict[str, str]:
+        """
+        Save all preprocessing plots from the current session
+
+        Returns:
+        --------
+        Dict[str, str] : dictionary with paths to saved plots
+        """
+        logger.info("Saving all preprocessing plots...")
+
+        plots_saved = {}
+
+        # Get all currently open matplotlib figures
+        figure_numbers = plt.get_fignums()
+
+        if not figure_numbers:
+            logger.warning("No open plots to save")
+            return plots_saved
+
+        # Save each plot
+        for fig_num in figure_numbers:
+            fig = plt.figure(fig_num)
+            filename = f"preprocessing_plot_{fig_num}.png"
+            filepath = self._save_figure(fig, filename, "preprocessing")
+            if filepath:
+                plots_saved[f"plot_{fig_num}"] = filepath
+
+        logger.info(f"Saved {len(plots_saved)} preprocessing plots")
+        return plots_saved
+
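A note on ordering: this method snapshots whatever figures are currently open via `plt.get_fignums()`, so it must run before any `plt.close('all')` in the pipeline. A minimal usage sketch, assuming `visualizer` is a hypothetical instance of this class:

```python
import matplotlib.pyplot as plt

# Call before closing figures; plt.get_fignums() only sees open ones.
saved = visualizer.save_all_preprocessing_plots()
for name, path in saved.items():
    print(f"{name} -> {path}")
plt.close('all')  # safe to clean up afterwards
```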
+    def create_all_visualizations(
+        self,
+        data: pd.DataFrame,
+        processed_data: pd.DataFrame = None,
+        feature_importance: Dict = None,
+        decomposition_result: Dict = None,
+        validation_results: Dict = None,
+        preprocessing_stages: Dict = None
+    ) -> Dict[str, str]:
+        """
+        Create all visualizations in one call
+
+        Parameters:
+        -----------
+        data : pd.DataFrame
+            Original data
+        processed_data : pd.DataFrame, optional
+            Processed data
+        feature_importance : Dict, optional
+            Feature importance
+        decomposition_result : Dict, optional
+            Decomposition results
+        validation_results : Dict, optional
+            Validation results
+        preprocessing_stages : Dict, optional
+            Preprocessing stages
+
+        Returns:
+        --------
+        Dict[str, str] : dictionary with paths to created plots
+        """
+        logger.info("\n" + "="*80)
+        logger.info("STARTING CREATION OF ALL VISUALIZATIONS")
+        logger.info("="*80)
+
+        result_files = {}
+
+        # 1. Summary dashboard
+        if data is not None:
+            logger.info("Creating summary dashboard...")
+            summary_path = self.create_summary_dashboard(data, preprocessing_stages)
+            if summary_path:
+                result_files['summary'] = summary_path
+
+        # 2. Correlation heatmaps
+        if data is not None:
+            logger.info("Creating correlation heatmaps...")
+            main_corr, target_corr = self.create_correlation_heatmap(data)
+            if main_corr:
+                result_files['correlation_main'] = main_corr
+            if target_corr:
+                result_files['correlation_target'] = target_corr
+
+        # 3. Distribution comparison
+        if data is not None and processed_data is not None:
+            logger.info("Creating distribution comparison...")
+            dist_path = self.create_distribution_comparison(data, processed_data)
+            if dist_path:
+                result_files['distribution'] = dist_path
+
+        # 4. Feature importance
+        if feature_importance:
+            logger.info("Creating feature importance plot...")
+            feat_path = self.create_feature_importance_plot(feature_importance)
+            if feat_path:
+                result_files['feature_importance'] = feat_path
+
+        # 5. Time series decomposition
+        if decomposition_result:
+            logger.info("Creating time series decomposition...")
+            decomp_path = self.create_time_series_decomposition_plot(decomposition_result)
+            if decomp_path:
+                result_files['decomposition'] = decomp_path
+
+        # 6. Data quality report
+        if validation_results:
+            logger.info("Creating data quality report...")
+            quality_path = self.create_data_quality_report(validation_results)
+            if quality_path:
+                result_files['quality_report'] = quality_path
+
+        # Save information about all plots
+        self.save_plots_info()
+
+        logger.info("\n" + "="*80)
+        logger.info("ALL VISUALIZATIONS CREATED SUCCESSFULLY")
+        logger.info("="*80)
+
+        for plot_name, plot_path in result_files.items():
+            if plot_path:
+                logger.info(f"✓ {plot_name}: {plot_path}")
+
+        return result_files
+
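An end-to-end call sketch for `create_all_visualizations`. Only `data` is required; each optional argument unlocks the corresponding plot. Here `visualizer` and the upstream objects are placeholders for values produced earlier in the pipeline, not names from this commit:

```python
result_files = visualizer.create_all_visualizations(
    data=raw_df,                                # original DataFrame (required)
    processed_data=processed_df,                # enables the distribution comparison
    feature_importance={"lag_1": 0.31,
                        "rolling_mean_7": 0.18},  # illustrative scores
    decomposition_result=decomposition_result,  # see the sketch above
    validation_results=validation_results,      # see the dict shape above
    preprocessing_stages=stages_dict,           # per-stage snapshots, if tracked
)
for name, path in result_files.items():
    print(f"{name}: {path}")
```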
+    def get_all_plots(self) -> Dict:
+        """Get information about all created plots"""
+        return self.plot_files
+
+    def save_plots_info(self, filename: str = "plots_info.json") -> None:
+        """Save plot information to a JSON file"""
+        try:
+            plots_info = {
+                'total_plots': len(self.plot_files),
+                'plots': self.plot_files,
+                'directories': {
+                    'correlations': self.correlations_dir,
+                    'distributions': self.distributions_dir,
+                    'features': self.features_dir,
+                    'time_series': self.time_series_dir,
+                    'preprocessing': self.preprocessing_dir,
+                    'summary': self.summary_dir,
+                    'reports': self.reports_dir
+                },
+                'generation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+                'config': {
+                    'target_column': self.config.target_column,
+                    'results_dir': self.config.results_dir
+                }
+            }
+
+            filepath = os.path.join(self.reports_dir, filename)
+
+            with open(filepath, 'w', encoding='utf-8') as f:
+                json.dump(plots_info, f, indent=4, ensure_ascii=False, default=str)
+
+            logger.info(f"✓ Plot information saved: {filepath}")
+
+        except Exception as e:
+            logger.error(f"✗ Error saving plot information: {e}")
+
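The manifest written by `save_plots_info` can be read back to build a report index or a Streamlit gallery. A small sketch; the path below assumes a `<results_dir>/plots/reports/` layout and is illustrative, since the directory attributes are configured elsewhere in the class:

```python
import json
import os

# Assumed location: <results_dir>/plots/reports/plots_info.json
manifest_path = os.path.join("streamlit_results", "plots", "reports", "plots_info.json")
with open(manifest_path, encoding="utf-8") as f:
    info = json.load(f)

print(f"{info['total_plots']} plots recorded at {info['generation_time']}")
for name, path in info["plots"].items():
    print(f"  {name}: {path}")
```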
+    def move_existing_plots(self, source_dir: str = None) -> Dict[str, str]:
+        """
+        Move existing plots from a specified directory into the structured folders
+
+        Parameters:
+        -----------
+        source_dir : str, optional
+            Directory with existing plots
+
+        Returns:
+        --------
+        Dict[str, str] : dictionary with information about moved files
+        """
+        if source_dir is None:
+            source_dir = self.plots_dir
+
+        if not os.path.exists(source_dir):
+            logger.warning(f"Source directory doesn't exist: {source_dir}")
+            return {}
+
+        # File to folder mapping
+        file_to_folder_map = {
+            # Time series
+            'data_split.png': 'time_series',
+            'stationarity_raskhodvoda.png': 'time_series',
+            'stationarity_analysis.png': 'time_series',
+            'temporal_outliers.png': 'time_series',
+
+            # Correlations
+            'feature_selection_correlation.png': 'correlations',
+
+            # Preprocessing
+            'missing_values_analysis.png': 'preprocessing',
+            'outlier_handling_results.png': 'preprocessing',
+            'outliers_analysis.png': 'preprocessing',
+            'scaling_results.png': 'preprocessing',
+
+            # Default
+            'default': 'summary'
+        }
+
+        moved_files = {}
+
+        for filename in os.listdir(source_dir):
+            if filename.endswith('.png'):
+                source_path = os.path.join(source_dir, filename)
+
+                # Determine the destination folder
+                target_folder = file_to_folder_map.get(filename, file_to_folder_map['default'])
+                target_dir = os.path.join(self.plots_dir, target_folder)
+
+                # Create the destination folder if it doesn't exist
+                os.makedirs(target_dir, exist_ok=True)
+
+                # Target path
+                target_path = os.path.join(target_dir, filename)
+
+                try:
+                    # Move the file (os.rename only works within one filesystem)
+                    os.rename(source_path, target_path)
+                    moved_files[filename] = target_path
+                    logger.info(f"Moved: {filename} -> {target_folder}/")
+                except Exception as e:
+                    logger.error(f"Error moving {filename}: {e}")
+
+        logger.info(f"Moved {len(moved_files)} files")
+        return moved_files
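
One caveat on the move logic above: `os.rename` only works within a single filesystem and raises `OSError` across mount points, which can matter in Docker setups with mounted volumes. A hedged drop-in for the inner `try` block using `shutil.move`, which falls back to copy-and-delete across filesystems:

```python
import shutil

try:
    # shutil.move copies and deletes when source and target
    # live on different filesystems, unlike os.rename
    shutil.move(source_path, target_path)
    moved_files[filename] = target_path
    logger.info(f"Moved: {filename} -> {target_folder}/")
except Exception as e:
    logger.error(f"Error moving {filename}: {e}")
```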