ArabovMK committed on
Commit
d8f69a9
·
1 Parent(s): d9e6371

Update all files

.gitignore CHANGED
@@ -1,2 +1,9 @@
1
  .venv/
2
  .venv
1
  .venv/
2
  .venv
3
+ __pycache__/
4
+ *__pycache__/
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ .Python
9
+ streamlit_results/
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.11-slim
2
 
3
  WORKDIR /app
4
 
 
1
+ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
README.md CHANGED
@@ -1,358 +1,268 @@
1
  ---
2
- title: Pectin Production Predictor
3
- emoji: 🧪
4
- colorFrom: green
5
- colorTo: blue
6
  sdk: docker
7
  pinned: true
8
  app_file: app.py
 
9
  ---
10
 
11
-
12
- # 🧪 Pectin Production Predictor
13
 
14
  <div align="center">
15
 
16
- **Predict Pectin Production Parameters: Multi-Model Comparison and Analysis**
17
 
18
- *Machine learning models for predicting pectin yield, galacturonic acid content, molecular weight, and esterification degree*
19
 
20
- [![Hugging Face](https://img.shields.io/badge/🤗-Hugging%20Face%20Space-blue)](https://huggingface.co/spaces/arabovs-ai-lab/PectinProductionModels)
21
- [![License](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
22
  [![Streamlit](https://img.shields.io/badge/Interface-Streamlit-FF4B4B)](https://streamlit.io)
 
23
 
24
  </div>
25
 
26
  ## 🌟 Overview
27
 
28
- Pectin Production Predictor provides a comprehensive suite of machine learning models for predicting key parameters in pectin production processes. The application enables researchers and industry professionals to optimize extraction conditions and predict product quality metrics using trained regression models.
29
-
30
- **Supported algorithms:**
31
- - **Gradient Boosting** (Best overall performance)
32
- - **Random Forest**
33
- - **XGBoost**
34
- - **Linear Regression**
35
-
36
- All models are published to the Hugging Face Hub under [`arabovs-ai-lab/PectinProductionModels`](https://huggingface.co/arabovs-ai-lab/PectinProductionModels) for easy access and integration.
37
-
38
- ---
39
-
40
- ## 🚀 Features
41
-
42
- ### 🔮 Prediction Capabilities
43
- - **Single prediction**: Interactive parameter tuning for individual experiments
44
- - **Batch processing**: Upload Excel/CSV files for bulk predictions
45
- - **Multi-model comparison**: Compare predictions across different algorithms
46
- - **Real-time visualization**: Interactive charts and performance metrics
47
-
48
- ### 📊 Analytical Tools
49
- - **Statistical comparison**: Side-by-side model performance analysis
50
- - **Quality metrics**: MAE, MSE, RMSE, R², MAPE, and correlation coefficients
51
- - **Visual analytics**: Distribution plots, comparison charts, and trend analysis
52
- - **Data validation**: Automatic file structure detection and data preprocessing
53
 
54
- ### 🎯 Target Parameters
55
- - **Pectin Yield (%)** - Extraction efficiency
56
- - **Galacturonic Acid (%)** - Pectin purity indicator
57
- - **Molecular Weight (Da)** - Molecular characteristics
58
- - **Esterification Degree (%)** - Functional properties
59
 
60
- ---
61
-
62
- ## 🧪 Input Parameters
63
-
64
- | Parameter | Description | Range | Unit |
65
- |-----------|-------------|-------|------|
66
- | **Sample Type** | Raw material type | 7 variants | - |
67
- | **Time** | Extraction duration | 0-300 | minutes |
68
- | **Temperature** | Extraction temperature | 0-200 | °C |
69
- | **Pressure** | Process pressure | 0-10 | atm |
70
- | **pH** | Acidity level | 0-14 | - |
71
-
72
- ### 📋 Sample Type Reference
73
 
74
- | Sample Code | Description |
75
- |-------------|-------------|
76
- | Абр. | Абрикосовый (Apricot) |
77
- | Рв. | Ревень (Rhubarb) |
78
- | Айв. | Айвы (Quince) |
79
- | Ткв | Тыквенный (Pumpkin) |
80
- | КрП | Корзинка подсолнечника (Sunflower head) |
81
- | ЯП(Ф) | Яблочный пектин Файзобод (Apple pectin Fayzobod) |
82
- | ЯП(М) | Яблочный пектин Муминобод (Apple pectin Muminobod) |
83
 
84
- ---
85
-
86
- ## 📊 Experimental Data Examples
 
 
87
 
88
- ### Sample Input-Output Data
89
 
90
- | Exp | Sample | Time (min) | Temp (°C) | Pressure (atm) | pH | Pectin Yield (%) | Galacturonic Acid (%) | Molecular Weight (Da) | Esterification Degree (%) |
91
- |-----|--------|------------|-----------|----------------|----|------------------|---------------------|----------------------|-------------------------|
92
- | 1 | ЯП(М) | 7 | 120 | 2.08 | 2.0 | 25.864 | 52.706 | 103,773.64 | 71.17 |
93
- | 2 | ЯП(М) | 7 | 120 | 1.74 | 2.08 | 24.830 | 51.645 | 103,098.49 | 70.015 |
94
- | 3 | Абр. | 5 | 130 | 2.09 | 1.74 | 14.755 | 67.550 | 127,235.35 | 82.813 |
95
- | 4 | ЯП(М) | 7 | 120 | 2.05 | 2.0 | 26.353 | 53.804 | 105,994.85 | 65.415 |
96
- | 5 | КрП | 60 | 85 | 1.03 | 2.0 | 19.505 | 66.606 | 145,498.37 | 67.756 |
97
 
98
- ---
99
 
100
- ## 📈 Model Performance
101
 
102
- ### Comprehensive Model Comparison (Test Set Averages)
103
 
104
- | Model | Avg Test R² | Avg Test MAE | Avg Test RMSE |
105
- |-------|------------:|-------------:|--------------:|
106
- | **Gradient Boosting** | **0.9427** | **868.440** | **1074.277** |
107
- | Random Forest | 0.9259 | 978.007 | 1214.291 |
108
- | XGBoost | 0.9203 | 1074.231 | 1327.170 |
109
- | Extra Trees | 0.9135 | 1060.174 | 1314.689 |
110
- | K-Neighbors | 0.8684 | 1287.513 | 2230.119 |
111
- | MultiLayer Perceptron | 0.8046 | 4253.843 | 5488.065 |
112
- | Linear Regression | 0.6965 | 3730.755 | 4818.582 |
113
- | Ridge Regression | 0.5553 | 3665.310 | 4850.510 |
114
- | SVR | 0.4832 | 6612.236 | 7939.850 |
115
- | Lasso Regression | 0.3846 | 3702.033 | 4828.528 |
116
 
117
- ### Production-Ready Models Selection
118
 
119
- | Model | Best For | Training Time | Robustness | Recommendation |
120
- |-------|----------|---------------|------------|----------------|
121
- | **Gradient Boosting** | Overall accuracy | Medium | High | ⭐⭐⭐⭐⭐ **Primary Choice** |
122
- | **Random Forest** | Stability & Speed | Fast | Very High | ⭐⭐⭐⭐ **Secondary Choice** |
123
- | **XGBoost** | Large datasets | Fast | High | ⭐⭐⭐⭐ **Alternative** |
124
- | Linear Regression | Baseline | Very Fast | Medium | ⭐⭐ **Reference** |
125
 
126
- > **Note:** Metrics represent averages across all four target variables. Gradient Boosting demonstrates superior performance across all evaluation criteria.
127
 
128
- ---
129
 
130
- ## 🛠️ Quick Start
131
 
132
- ### Installation with Version Control
133
- ```bash
134
- # Install with specific Streamlit version to avoid compatibility issues
135
- pip install streamlit==1.28.0 pandas numpy plotly scikit-learn joblib huggingface-hub xgboost
136
- ```
137
 
138
- ### Run the Application
139
  ```bash
140
- streamlit run app_pectin.py
141
- ```
 
142
 
143
- ### Troubleshooting Version Issues
144
- If you encounter Streamlit version warnings, ensure you're using the tested version:
145
- ```bash
146
- pip install --force-reinstall streamlit==1.28.0
147
- ```
148
 
149
- ### Programmatic Usage
150
- ```python
151
- from huggingface_hub import hf_hub_download
152
- import joblib
153
-
154
- # Download model artifacts
155
- model_path = hf_hub_download(
156
- repo_id="arabovs-ai-lab/PectinProductionModels",
157
- filename="gradient_boosting_model.pkl", # Best performing model
158
- repo_type="model"
159
- )
160
-
161
- # Load model
162
- model = joblib.load(model_path)
163
-
164
- # Example prediction
165
- input_data = {
166
- 'sample': 'ЯП(М)',
167
- 'time_min': 120,
168
- 'temperature_c': 90,
169
- 'pressure_atm': 1.5,
170
- 'ph': 2.0
171
- }
172
- prediction = model.predict(input_data)
173
  ```
174
 
175
- ---
176
-
177
- ## 📦 Model Architecture
178
-
179
- ### Feature Engineering
180
- - **Sample encoding**: Intelligent categorical encoding of 7 pectin source types
181
- - **Method detection**: Automatic extraction method classification
182
- - **Feature scaling**: Standardized input parameters (Time, Temperature, Pressure, pH)
183
- - **Multi-output regression**: Simultaneous prediction of 4 target variables
184
- - **Cross-validation**: Robust training with 5-fold cross-validation
185
 
186
- ### Model Structure
187
- ```
188
- Input Features → Preprocessing → Multi-output Regression → Target Predictions
-  (5 parameters)   (Encoding & Scaling)   (GB, RF, XGB, LR)    (4 quality metrics)
192
  ```
193
 
194
- ---
195
-
196
- ## 🔧 Usage Examples
197
 
198
- ### Single Prediction
199
  ```python
200
- # Optimal extraction parameters example
201
- optimal_input = {
202
-     'sample': 'ЯП(М)',
203
-     'time_min': 7,
204
-     'temperature_c': 120,
205
-     'pressure_atm': 2.08,
206
-     'ph': 2.0
207
  }
208
 
209
- # Expected output based on experimental data:
210
- # Pectin Yield: ~25.8%, Galacturonic Acid: ~52.7%
211
- # Molecular Weight: ~103,774 Da, Esterification: ~71.2%
212
- ```
213
 
214
- ### Batch Processing Template
215
- ```csv
216
- sample,time_min,temperature_c,pressure_atm,ph
217
- ЯП(М),7,120,2.08,2.0
218
- Абр.,5,130,2.09,1.74
219
- КрП,60,85,1.03,2.0
220
  ```
221
 
222
- ---
223
-
224
- ## 📊 Performance Analysis
225
-
226
- ### Key Findings from Model Evaluation
227
- 1. **Gradient Boosting** achieves the highest R² score (0.9427) indicating excellent explanatory power
228
- 2. **Tree-based models** (GB, RF, XGBoost) significantly outperform linear models
229
- 3. **Random Forest** provides the best balance of performance and training speed
230
- 4. **Ensemble methods** demonstrate superior robustness to experimental variability
231
 
232
- ### Metric Interpretation
233
- - **R² (0.9427)**: Models explain 94.27% of variance in target variables
234
- - **MAE (868.44)**: Average prediction error across all targets
235
- - **RMSE (1074.28)**: Standard deviation of prediction residuals
 
236
 
237
- ---
238
-
239
- ## 🎯 Application Interface
240
 
241
- ### Main Tabs
242
- 1. **🎯 Single Prediction**: Interactive parameter tuning with real-time predictions
243
- 2. **📁 Batch Processing**: File upload and bulk analysis with validation
244
- 3. **📊 Model Comparison**: Side-by-side performance evaluation
245
- 4. **🔄 Multi-Model Processing**: Apply multiple algorithms simultaneously
246
 
247
- ### Input Validation
248
- - **Range checking**: Automatic validation of physical parameter limits
249
- - **Sample verification**: Validation against supported sample types
250
- - **Data integrity**: Comprehensive file structure validation for batch processing
251
-
252
- ---
253
 
254
- ## 📁 Data Compatibility
255
 
256
- ### Supported Formats
257
- - **Excel** (.xlsx, .xls): Automatic structure detection and sheet selection
258
- - **CSV/TXT**: Multiple encoding support (UTF-8, Windows-1251) and delimiter auto-detection
259
- - **Column mapping**: Automatic Russian ↔ English column name translation
260
 
261
- ### Required Input Format
262
- ```csv
263
- sample,time_min,temperature_c,pressure_atm,ph
264
- ЯП(М),7,120,2.08,2.0
265
- Абр.,5,130,2.09,1.74
 
 
266
  ```
267
 
268
- ### Optional Ground Truth Columns
269
- ```csv
270
- sample,time_min,temperature_c,pressure_atm,ph,pectin_yield,galacturonic_acid,molecular_weight,esterification_degree
271
- ЯП(М),7,120,2.08,2.0,25.864,52.706,103773.64,71.17
 
272
  ```
273
 
274
- ---
275
-
276
- ## 🔬 Scientific Background
277
-
278
- ### Pectin Production Optimization
279
- Pectin is a complex polysaccharide found in plant cell walls, with critical applications as:
280
- - **Gelling agent** in jams, jellies, and confectionery
281
- - **Stabilizer** in dairy products and beverages
282
- - **Pharmaceutical excipient** in drug delivery systems
283
- - **Functional ingredient** in cosmetic formulations
284
-
285
- ### Extraction Parameter Effects
286
- - **Time & Temperature**: Directly impact yield and molecular weight degradation
287
- - **pH**: Critical for galacturonic acid content and esterification degree
288
- - **Pressure**: Influences extraction efficiency and pectin quality
289
- - **Raw Material**: Determines inherent pectin characteristics and optimal conditions
290
-
291
- ---
292
-
293
  ## 🤝 Contributing
294
 
295
- We welcome contributions in the following areas:
296
 
297
- - **New algorithms**: Additional machine learning models (Neural Networks, etc.)
298
- - **Feature engineering**: Advanced preprocessing and feature selection techniques
299
- - **Visualization**: Enhanced analytical dashboards and interactive plots
300
- - **Documentation**: Additional usage examples and case studies
 
 
301
 
302
- ### Development Setup
303
  ```bash
304
- git clone https://huggingface.co/spaces/arabovs-ai-lab/PectinProductionModels
305
- cd PectinProductionModels
306
- pip install -r requirements.txt
307
- streamlit run app_pectin.py
308
- ```
309
 
310
- ### Version Management
311
- To ensure compatibility, maintain the specified package versions:
312
- ```txt
313
- streamlit==1.28.0
314
- scikit-learn>=1.0.0
315
- xgboost>=1.5.0
316
  ```
317
 
318
- ---
319
-
320
- ## 📜 Citation
321
 
322
- If you use this tool in your research, please cite:
323
-
324
- ```bibtex
325
- @misc{pectin_predictor_2024,
326
- title = {Pectin Production Predictor: Machine Learning Models for Pectin Quality Prediction},
327
- author = {Arabovs AI Lab},
328
- year = {2024},
329
- publisher = {Hugging Face},
330
- url = {https://huggingface.co/arabovs-ai-lab/PectinProductionModels}
331
- }
332
- ```
333
 
334
- ---
335
 
336
- ## 📄 License
337
 
338
- MIT License - See [LICENSE](LICENSE) file for details.
339
 
340
- ---
341
 
342
- ## 🏛️ Institutional Support
343
 
344
- This project is maintained by **Arabovs AI Lab** as part of our commitment to advancing applied machine learning in industrial and biotechnological applications.
345
 
346
  ---
347
 
348
  <div align="center">
349
 
350
- **Advancing Biotechnology with Machine Learning**
351
- Brought to you by **Arabovs AI Lab**
352
 
353
- [![Repository](https://img.shields.io/badge/🔗-Model%20Repository-171717)](https://huggingface.co/arabovs-ai-lab/PectinProductionModels)
354
- [![Live Demo](https://img.shields.io/badge/🚀-Live%20Demo-FF4B4B)](https://huggingface.co/spaces/arabovs-ai-lab/PectinProductionModels)
355
 
356
- </div>
 
357
 
358
- *Last updated: November 2025*
 
1
  ---
2
+ title: TimeFlow Pro
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
  pinned: true
8
  app_file: app.py
9
+ sdk_version: 1.52.2
10
  ---
11
 
12
+ # 📊 TimeFlow Pro
 
13
 
14
  <div align="center">
15
 
16
+ **Intelligent Time Series Data Analysis and Preprocessing Platform**
17
 
18
+ *Advanced pipeline for data preparation and feature engineering*
19
 
20
+ [![Hugging Face](https://img.shields.io/badge/🤗-Hugging%20Face%20Space-blue)](https://huggingface.co/spaces/your-username/timeflow-pro)
 
21
  [![Streamlit](https://img.shields.io/badge/Interface-Streamlit-FF4B4B)](https://streamlit.io)
22
+ [![Python](https://img.shields.io/badge/Python-3.9+-blue)](https://python.org)
23
 
24
  </div>
25
 
26
  ## 🌟 Overview
27
 
28
+ TimeFlow Pro is a comprehensive platform for time series data analysis, preprocessing, and feature engineering. Designed for data scientists and analysts, it provides an intuitive interface for transforming raw time series data into ML-ready datasets with advanced preprocessing capabilities.
 
29
 
30
+ ## 🚀 Key Features
31
 
32
+ ### 📈 **Data Analysis & Visualization**
33
+ - **Interactive Data Exploration**: Real-time preview and statistics
34
+ - **Missing Value Analysis**: Smart detection and handling strategies
35
+ - **Outlier Detection**: Multiple methods including IQR, Z-Score, Isolation Forest
36
+ - **Temporal Analysis**: Seasonality detection, trend analysis, decomposition
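+ 
+ As a point of reference, the IQR rule listed above flags values outside [Q1 − 1.5·IQR, Q3 + 1.5·IQR]. A minimal pandas sketch, assuming `series` is one numeric column (illustrative, not the app's internal code):
+ 
+ ```python
+ q1, q3 = series.quantile(0.25), series.quantile(0.75)
+ iqr = q3 - q1
+ # Points beyond 1.5 * IQR on either side are flagged as outliers
+ outliers = series[(series < q1 - 1.5 * iqr) | (series > q3 + 1.5 * iqr)]
+ ```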
37
 
38
+ ### ⚙️ **Advanced Preprocessing Pipeline**
39
+ - **Feature Engineering**: Automatic lag features, rolling statistics, seasonal components
40
+ - **Stationarity Checking**: ADF tests and transformation suggestions
41
+ - **Data Scaling**: Robust, Standard, MinMax, and custom scaling methods
42
+ - **Feature Selection**: Correlation, variance, mutual information, RF importance
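+ 
+ The lag and rolling-statistics features from the first bullet above can be pictured with a small pandas sketch; the function name and defaults are illustrative, not TimeFlow Pro's internal API:
+ 
+ ```python
+ import pandas as pd
+ 
+ def add_basic_features(df: pd.DataFrame, target: str, max_lags: int = 5,
+                        windows: tuple = (7, 30)) -> pd.DataFrame:
+     out = df.copy()
+     for lag in range(1, max_lags + 1):
+         out[f"{target}_lag_{lag}"] = out[target].shift(lag)   # lag features
+     for w in windows:
+         out[f"{target}_roll_mean_{w}"] = out[target].rolling(w).mean()  # rolling stats
+         out[f"{target}_roll_std_{w}"] = out[target].rolling(w).std()
+     return out.dropna()  # drop rows lost to shift/rolling warm-up
+ ```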
43
 
44
+ ### 🏗️ **ML-Ready Outputs**
45
+ - **Train/Validation/Test Splits**: Time-based or random splitting
46
+ - **Multiple Export Formats**: CSV, Parquet, Excel, JSON
47
+ - **Model Integration**: Ready-to-use datasets for scikit-learn, XGBoost, LightGBM
48
+ - **Visual Reports**: Comprehensive pipeline execution reports
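+ 
+ A time-based split keeps chronological order so the test set never leaks future information into training. In plain pandas it amounts to the sketch below, assuming `df` is the loaded DataFrame sorted by time (ratios are illustrative):
+ 
+ ```python
+ n = len(df)
+ train = df.iloc[: int(0.7 * n)]            # earliest 70%
+ val = df.iloc[int(0.7 * n) : int(0.8 * n)]  # next 10%
+ test = df.iloc[int(0.8 * n) :]              # most recent 20%
+ ```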
49
 
50
+ ## 🎮 Quick Start
51
 
52
+ ### 1. **Upload Your Data**
53
+ - Support for CSV, Excel, Parquet formats
54
+ - Automatic date parsing and validation
55
+ - Smart column type detection
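+ 
+ The equivalent of this step in plain pandas looks roughly like the following (the `date` column name is an assumption for illustration):
+ 
+ ```python
+ import pandas as pd
+ 
+ df = pd.read_csv("your_data.csv", parse_dates=["date"])  # automatic date parsing
+ df = df.sort_values("date").set_index("date")            # chronological index
+ ```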
56
 
57
+ ### 2. **Configure Pipeline**
58
+ ```python
59
+ # Example configuration
60
+ config = {
61
+     'target_column': 'sales',
62
+     'test_size': 0.2,
63
+     'max_lags': 5,
64
+     'seasonal_period': 365,
65
+     'scaling_method': 'robust'
66
+ }
67
+ ```
68
 
69
+ ### 3. **Run Pipeline & Export**
70
+ - Execute full preprocessing pipeline
71
+ - Download processed data
72
+ - Get feature importance reports
73
+ - Export modeling datasets
74
 
75
+ ## 📊 Technical Architecture
76
 
77
+ ### 🔧 **Pipeline Components**
78
+ ```
79
+ Data Loading → Validation → Missing Handling → Outlier Treatment
80
+                                 ↓
81
+ Feature Engineering → Stationarity Check → Correlation Analysis
82
+                                 ↓
83
+ Data Splitting → Scaling → Feature Selection → Final Validation
84
+ ```
85
 
86
+ ### 🏆 **Core Features**
87
+ - **Multi-stage Validation**: Raw, processed, and final data validation
88
+ - **Memory Optimization**: Efficient handling of large datasets
89
+ - **Error Recovery**: Graceful handling of pipeline failures
90
+ - **Reproducible Results**: Configuration saving and logging
91
 
92
+ ## 📚 Use Cases
93
 
94
+ ### 🏢 **Business Analytics**
95
+ - Sales forecasting and trend analysis
96
+ - Inventory optimization
97
+ - Customer behavior prediction
98
+ - Financial time series analysis
99
 
100
+ ### 🏭 **Industrial Applications**
101
+ - Sensor data preprocessing
102
+ - Predictive maintenance
103
+ - Quality control monitoring
104
+ - Energy consumption forecasting
105
 
106
+ ### 🎓 **Academic Research**
107
+ - Time series modeling experiments
108
+ - Feature engineering research
109
+ - Algorithm comparison studies
110
+ - Educational tool for data science
111
 
112
+ ## 🛠️ Installation
113
 
114
+ ### Local Development
115
  ```bash
116
+ # Clone repository
117
+ git clone https://huggingface.co/spaces/your-username/timeflow-pro
118
+ cd timeflow-pro
119
 
120
+ # Install dependencies
121
+ pip install -r requirements.txt
122
 
123
+ # Run application
124
+ streamlit run app.py
125
  ```
126
 
127
+ ### Docker Deployment
128
+ ```bash
129
+ # Build Docker image
130
+ docker build -t timeflow-pro .
131
 
132
+ # Run container
133
+ docker run -p 8501:8501 timeflow-pro
134
  ```
135
 
136
+ ## 🌐 API Usage Example
 
 
137
 
 
138
  ```python
139
+ from timeflow_pro import TimeFlowPipeline
140
+ import pandas as pd
141
+
142
+ # Load your data
143
+ data = pd.read_csv('your_data.csv')
144
+
145
+ # Configure pipeline
146
+ config = {
147
+     'target_column': 'target',
148
+     'test_size': 0.2,
149
+     'max_lags': 7,
150
+     'seasonal_period': 30
151
  }
152
 
153
+ # Create and run pipeline
154
+ pipeline = TimeFlowPipeline(config)
155
+ processed_data = pipeline.run(data)
 
156
 
157
+ # Get modeling data
158
+ modeling_data = pipeline.get_modeling_data()
159
+ X_train, y_train = modeling_data['X_train'], modeling_data['y_train']
160
  ```
161
 
162
+ ## 📈 Performance Benchmarks
163
 
164
+ | Dataset Size | Processing Time | Memory Usage | Features Generated |
165
+ |--------------|----------------|--------------|-------------------|
166
+ | 10K rows | ~5 seconds | <500 MB | 50-100 features |
167
+ | 100K rows | ~30 seconds | <1 GB | 100-200 features |
168
+ | 1M rows | ~5 minutes | <2 GB | 200-500 features |
169
 
170
+ ## 🔧 Configuration Options
 
 
171
 
172
+ ### **Data Processing**
173
+ - `missing_threshold`: Threshold for column removal (0.0-0.5)
174
+ - `outlier_method`: IQR, Z-Score, or Isolation Forest
175
+ - `scaling_method`: Robust, Standard, MinMax, or None
 
176
 
177
+ ### **Feature Engineering**
178
+ - `max_lags`: Maximum lag features (1-20)
179
+ - `seasonal_period`: Seasonal window (7, 30, 90, 365)
180
+ - `rolling_windows`: List of rolling windows [7, 30, 90]
 
 
181
 
182
+ ### **Model Preparation**
183
+ - `feature_selection_method`: Correlation, Variance, RF, Mutual Info
184
+ - `max_features`: Maximum features to select (5-100)
185
+ - `split_method`: Time-based or random splitting
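+ 
+ Taken together, these options form a single configuration dictionary; the sketch below combines the documented keys with illustrative values:
+ 
+ ```python
+ config = {
+     # Data processing
+     'missing_threshold': 0.3,
+     'outlier_method': 'iqr',
+     'scaling_method': 'robust',
+     # Feature engineering
+     'max_lags': 7,
+     'seasonal_period': 30,
+     'rolling_windows': [7, 30, 90],
+     # Model preparation
+     'feature_selection_method': 'correlation',
+     'max_features': 50,
+     'split_method': 'time',
+ }
+ ```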
186
 
187
+ ## 📋 Requirements
188
 
189
+ ### **Core Dependencies**
190
+ ```txt
191
+ streamlit>=1.28.0
192
+ pandas>=2.0.0
193
+ numpy>=1.24.0
194
+ plotly>=5.17.0
195
+ scikit-learn>=1.3.0
196
  ```
197
 
198
+ ### **Optional Dependencies**
199
+ ```txt
200
+ xgboost>=2.0.0 # For XGBoost feature importance
201
+ lightgbm>=4.0.0 # For LightGBM integration
202
+ statsmodels>=0.14.0 # For advanced time series analysis
203
  ```
204
 
205
  ## 🤝 Contributing
206
 
207
+ We welcome contributions! Here's how you can help:
208
 
209
+ ### **Areas for Contribution**
210
+ 1. **New Feature Engineering Methods**
211
+ 2. **Additional Visualization Types**
212
+ 3. **Export Format Support**
213
+ 4. **Performance Optimizations**
214
+ 5. **Documentation Improvements**
215
 
216
+ ### **Development Workflow**
217
  ```bash
218
+ # 1. Fork the repository
219
+ # 2. Create feature branch
220
+ git checkout -b feature/new-feature
 
 
221
 
222
+ # 3. Make changes and test
223
+ # 4. Submit pull request
224
  ```
225
 
226
+ ## 📜 License
 
 
227
 
228
+ This project is licensed under the **MIT License** - see the [LICENSE](LICENSE) file for details.
229
 
230
+ ## 🙏 Acknowledgments
231
 
232
+ ### **Special Thanks To:**
233
+ - **Streamlit Team** for the amazing framework
234
+ - **Hugging Face** for hosting the Space
235
+ - **Open Source Community** for invaluable libraries
236
+ - **All Contributors** who helped improve TimeFlow Pro
237
 
238
+ ### **Built With:**
239
+ - 🐍 Python
240
+ - 📊 Streamlit
241
+ - 🎨 Plotly
242
+ - 🔧 Scikit-learn
243
+ - 📈 Pandas & NumPy
244
 
245
+ ## 📞 Support & Contact
246
 
247
+ ### **Get Help:**
248
+ - 📧 **Email**: cool.araby@gmail.com
249
+ - 💬 **Issues**: [GitHub Issues](https://github.com/your-username/timeflow-pro/issues)
250
+ - 💡 **Discussions**: [Community Forum](https://github.com/your-username/timeflow-pro/discussions)
251
 
252
+ ### **Stay Updated:**
253
+ - ⭐ **Star** the repository
254
+ - 👁️ **Watch** for releases
255
+ - 🔔 **Enable notifications**
256
 
257
  ---
258
 
259
  <div align="center">
260
 
261
+ **Transform Your Time Series Data with Ease**
 
262
 
263
+ *TimeFlow Pro - Making Data Preparation Simple and Powerful*
 
264
 
265
+ [![Follow on Hugging Face](https://img.shields.io/badge/Follow%20on-🤗%20Hugging%20Face-yellow)](https://huggingface.co/your-username)
266
+ [![GitHub Stars](https://img.shields.io/github/stars/your-username/timeflow-pro?style=social)](https://github.com/your-username/timeflow-pro)
267
 
268
+ </div>
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
config/__init__.py ADDED
File without changes
config/config.py ADDED
@@ -0,0 +1,169 @@
1
+ # ============================================
2
+ # ENUMERATION CLASSES
3
+ # ============================================
4
+ from dataclasses import asdict, dataclass, field
5
+ from enum import Enum
6
+ import json
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class DataType(Enum):
14
+ """Data types"""
15
+ NUMERIC = "numeric"
16
+ CATEGORICAL = "categorical"
17
+ TEMPORAL = "temporal"
18
+ TEXT = "text"
19
+
20
+
21
+ class PreprocessingMethod(Enum):
22
+ """Data preprocessing methods"""
23
+ FILL_MEAN = "fill_mean"
24
+ FILL_MEDIAN = "fill_median"
25
+ FILL_INTERPOLATE = "fill_interpolate"
26
+ FILL_KNN = "fill_knn"
27
+ REMOVE = "remove"
28
+ CLIP = "clip"
29
+ WINSORIZE = "winsorize"
30
+ NORMALIZE = "normalize"
31
+ STANDARDIZE = "standardize"
32
+ LOG_TRANSFORM = "log_transform"
33
+ BOX_COX = "box_cox"
34
+ DIFFERENCING = "differencing"
35
+
36
+
37
+ class SeasonalityType(Enum):
38
+ """Seasonality types"""
39
+ DAILY = "daily"
40
+ WEEKLY = "weekly"
41
+ MONTHLY = "monthly"
42
+ QUARTERLY = "quarterly"
43
+ YEARLY = "yearly"
44
+ MULTIPLE = "multiple"
45
+
46
+
47
+ # ============================================
48
+ # CLASS 1: CONFIGURATION
49
+ # ============================================
50
+ @dataclass
51
+ class Config:
52
+ """Experiment configuration for data preprocessing"""
53
+
54
+ # Paths and directories
55
+ data_path: str = 'temp_data.csv'
56
+ results_dir: str = 'data_preprocessing_results'
57
+
58
+ # Temporal parameters
59
+ start_year: int = 1970
60
+ end_year: int = 1990
61
+ freq: str = 'D' # Data frequency: D (daily), H (hourly), M (monthly)
62
+
63
+ # Target variable
64
+ target_column: str = 'raskhodvoda'
65
+
66
+ # Feature parameters
67
+ max_lags: int = 12
68
+ seasonal_period: int = 365
69
+ rolling_windows: List[int] = field(default_factory=lambda: [7, 30, 90, 365])
70
+ expanding_windows: List[int] = field(default_factory=lambda: [30, 90, 365])
71
+
72
+ # Processing parameters
73
+ missing_threshold: float = 0.3 # Threshold for dropping columns with missing values
74
+ outlier_method: str = 'iqr' # Outlier detection method: iqr, zscore, lof
75
+ outlier_alpha: float = 1.5 # IQR multiplier
76
+ outlier_contamination: float = 0.1 # For methods like LOF
77
+
78
+ # Data splitting
79
+ test_size: float = 0.2
80
+ validation_size: float = 0.1
81
+ split_method: str = 'time' # time, random, expanding_window
82
+
83
+ # Scaling
84
+ scaling_method: str = 'robust' # standard, minmax, robust, none
85
+
86
+ # Feature selection
87
+ feature_selection_method: str = 'correlation' # correlation, mutual_info, rf, pca
88
+ max_features: int = 50
89
+
90
+ # Validation
91
+ enable_validation: bool = True
92
+ validation_rules: Dict = field(default_factory=dict)
93
+
94
+ # Visualisation
95
+ save_plots: bool = True
96
+ plot_style: str = 'seaborn-v0_8'
97
+
98
+ # Performance
99
+ use_multiprocessing: bool = False
100
+ n_jobs: int = -1
101
+ chunk_size: int = 10000
102
+
103
+ # Logging
104
+ log_level: str = 'INFO'
105
+ save_reports: bool = True
106
+
107
+ def __post_init__(self):
108
+ """Post-initialisation for creating directories and setting up logging"""
109
+ self.create_directories()
110
+ self.setup_logging()
111
+
112
+ # Setting default validation rules
113
+ if not self.validation_rules:
114
+ self.validation_rules = {
115
+ 'min_rows': 100,
116
+ 'max_missing_percentage': 30,
117
+ 'min_unique_values': 2,
118
+ 'max_skewness': 3,
119
+ 'max_kurtosis': 10
120
+ }
121
+
122
+ def create_directories(self) -> None:
123
+ """Create directories for preprocessing results"""
124
+ dirs = [
125
+ self.results_dir,
126
+ f'{self.results_dir}/plots',
127
+ f'{self.results_dir}/plots/time_series',
128
+ f'{self.results_dir}/plots/distributions',
129
+ f'{self.results_dir}/plots/correlations',
130
+ f'{self.results_dir}/plots/features',
131
+ f'{self.results_dir}/tables',
132
+ f'{self.results_dir}/processed_data',
133
+ f'{self.results_dir}/models',
134
+ f'{self.results_dir}/reports',
135
+ f'{self.results_dir}/logs',
136
+ f'{self.results_dir}/checkpoints'
137
+ ]
138
+
139
+ for directory in dirs:
140
+ Path(directory).mkdir(parents=True, exist_ok=True)
141
+
142
+ logger.info(f"Directories created in {self.results_dir}")
143
+
144
+ def setup_logging(self) -> None:
145
+ """Configure logging"""
146
+ log_level = getattr(logging, self.log_level.upper())
147
+ logger.setLevel(log_level)
148
+
149
+ def to_dict(self) -> Dict:
150
+ """Convert configuration to dictionary"""
151
+ return asdict(self)
152
+
153
+ def save(self, path: Optional[str] = None) -> None:
154
+ """Save configuration to file"""
155
+ if path is None:
156
+ path = f'{self.results_dir}/config.json'
157
+
158
+ with open(path, 'w', encoding='utf-8') as f:
159
+ json.dump(self.to_dict(), f, indent=4, ensure_ascii=False)
160
+
161
+ logger.info(f"Configuration saved to {path}")
162
+
163
+ @classmethod
164
+ def load(cls, path: str) -> 'Config':
+     """Load configuration from file, ignoring keys the dataclass does not define"""
+     with open(path, 'r', encoding='utf-8') as f:
+         config_dict = json.load(f)
+ 
+     # default_config.json carries extra keys (e.g. plot_dpi, streamlit_settings)
+     # that Config does not define; passing them to cls(**...) would raise TypeError
+     from dataclasses import fields
+     known = {f.name for f in fields(cls)}
+     return cls(**{k: v for k, v in config_dict.items() if k in known})
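+ 
+ 
+ # Usage sketch (illustrative): round-trip a configuration
+ #   cfg = Config(target_column='sales', max_lags=7)
+ #   cfg.save()                                   # writes <results_dir>/config.json
+ #   same = Config.load(f'{cfg.results_dir}/config.json')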
config/default_config.json ADDED
@@ -0,0 +1,78 @@
1
+ {
2
+ "data_path": "temp_data.csv",
3
+ "results_dir": "results",
4
+
5
+ "start_year": 1970,
6
+ "end_year": 1990,
7
+ "freq": "D",
8
+
9
+ "target_column": "raskhodvoda",
10
+
11
+ "max_lags": 12,
12
+ "seasonal_period": 365,
13
+ "rolling_windows": [7, 30, 90, 365],
14
+ "expanding_windows": [30, 90, 365],
15
+
16
+ "missing_threshold": 0.3,
17
+ "outlier_method": "iqr",
18
+ "outlier_alpha": 1.5,
19
+ "outlier_contamination": 0.1,
20
+
21
+ "test_size": 0.2,
22
+ "validation_size": 0.1,
23
+ "split_method": "time",
24
+
25
+ "scaling_method": "robust",
26
+
27
+ "feature_selection_method": "correlation",
28
+ "max_features": 50,
29
+
30
+ "enable_validation": true,
31
+ "validation_rules": {
32
+ "min_rows": 100,
33
+ "max_missing_percentage": 30,
34
+ "min_unique_values": 2,
35
+ "max_skewness": 3,
36
+ "max_kurtosis": 10,
37
+ "min_variance": 0.001,
38
+ "max_constant_columns": 0
39
+ },
40
+
41
+ "save_plots": true,
42
+ "plot_style": "seaborn-whitegrid",
43
+ "plot_dpi": 300,
44
+ "plot_format": "png",
45
+
46
+ "use_multiprocessing": false,
47
+ "n_jobs": -1,
48
+ "chunk_size": 10000,
49
+ "memory_limit_gb": 4,
50
+
51
+ "log_level": "INFO",
52
+ "save_reports": true,
53
+ "report_format": "json",
54
+
55
+ "decomposition_method": "stl",
56
+ "stationarity_tests": ["adf", "kpss"],
57
+ "correlation_threshold": 0.85,
58
+ "vif_threshold": 10,
59
+
60
+ "random_seed": 42,
61
+ "enable_profiling": false,
62
+ "save_intermediate": true,
63
+
64
+ "streamlit_settings": {
65
+ "theme": "light",
66
+ "sidebar_state": "expanded",
67
+ "page_title": "Time Series Preprocessing",
68
+ "page_icon": "📊",
69
+ "layout": "wide"
70
+ },
71
+
72
+ "export_options": {
73
+ "csv": true,
74
+ "parquet": false,
75
+ "excel": false,
76
+ "pickle": true
77
+ }
78
+ }
config/settings.py ADDED
@@ -0,0 +1,375 @@
1
+ """
2
+ General project settings: visualisation, paths, constants
3
+ """
4
+
5
+ import warnings
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ from pathlib import Path
9
+ from typing import Dict, Any, Optional
10
+ import yaml
11
+ import json
12
+ import os
13
+
14
+ # ============================================================================
15
+ # PATHS AND DIRECTORIES
16
+ # ============================================================================
17
+
18
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
19
+ DATA_DIR = PROJECT_ROOT / "data"
20
+ RAW_DATA_DIR = DATA_DIR / "raw"
21
+ PROCESSED_DATA_DIR = DATA_DIR / "processed"
22
+ EXTERNAL_DATA_DIR = DATA_DIR / "external"
23
+
24
+ RESULTS_DIR = PROJECT_ROOT / "results"
25
+ PLOTS_DIR = RESULTS_DIR / "plots"
26
+ MODELS_DIR = RESULTS_DIR / "models"
27
+ REPORTS_DIR = RESULTS_DIR / "reports"
28
+ LOGS_DIR = RESULTS_DIR / "logs"
29
+
30
+ CONFIGS_DIR = PROJECT_ROOT / "configs"
31
+ NOTEBOOKS_DIR = PROJECT_ROOT / "notebooks"
32
+ TESTS_DIR = PROJECT_ROOT / "tests"
33
+
34
+ # Create directories on import
35
+ for directory in [RAW_DATA_DIR, PROCESSED_DATA_DIR, EXTERNAL_DATA_DIR,
36
+ PLOTS_DIR, MODELS_DIR, REPORTS_DIR, LOGS_DIR]:
37
+ directory.mkdir(parents=True, exist_ok=True)
38
+
39
+ # ============================================================================
40
+ # VISUALISATION SETTINGS
41
+ # ============================================================================
42
+
43
+ def setup_visualization(
44
+ style: str = "seaborn-whitegrid",
45
+ palette: str = "husl",
46
+ context: str = "notebook",
47
+ font_scale: float = 1.0,
48
+ dpi: int = 150,
49
+ figsize: tuple = (12, 6),
50
+ **kwargs
51
+ ):
52
+ """
53
+ Configure visualisation parameters for matplotlib and seaborn
54
+
55
+ Parameters:
56
+ -----------
57
+ style : str
58
+ Matplotlib style: 'seaborn-v0_8-whitegrid', 'ggplot', 'bmh', 'dark_background'
59
+ palette : str
60
+ Seaborn palette: 'husl', 'Set2', 'viridis', 'mako'
61
+ context : str
62
+ Seaborn context: 'paper', 'notebook', 'talk', 'poster'
63
+ font_scale : float
64
+ Font scale
65
+ dpi : int
66
+ Plot resolution
67
+ figsize : tuple
68
+ Default figure size
69
+ """
70
+ # Ignore warnings
71
+ warnings.filterwarnings('ignore')
72
+
73
+ # Matplotlib settings
74
+ plt.style.use(style)
75
+
76
+ # RC parameters
77
+ rc_params = {
78
+ 'font.size': 10,
79
+ 'figure.figsize': figsize,
80
+ 'figure.dpi': dpi,
81
+ 'savefig.dpi': 300,
82
+ 'savefig.bbox': 'tight',
83
+ 'savefig.format': 'png',
84
+ 'axes.titlesize': 12,
85
+ 'axes.labelsize': 10,
86
+ 'xtick.labelsize': 9,
87
+ 'ytick.labelsize': 9,
88
+ 'legend.fontsize': 9,
89
+ 'font.family': ['DejaVu Sans', 'Arial', 'sans-serif'],
90
+ 'figure.titlesize': 14,
91
+ 'axes.grid': True,
92
+ 'grid.alpha': 0.3,
93
+ 'lines.linewidth': 1.5,
94
+ 'lines.markersize': 6,
95
+ 'patch.edgecolor': 'black',
96
+ 'patch.force_edgecolor': True,
97
+ 'xtick.top': False,
98
+ 'ytick.right': False,
99
+ 'axes.spines.top': False,
100
+ 'axes.spines.right': False
101
+ }
102
+
103
+ # Update additional parameters
104
+ rc_params.update(kwargs)
105
+ plt.rcParams.update(rc_params)
106
+
107
+ # Seaborn settings
108
+ sns.set_style(style.replace('seaborn-v0_8-', '').replace('seaborn-', ''))
109
+ sns.set_palette(palette)
110
+ sns.set_context(context, font_scale=font_scale)
111
+
112
+ print(f"✓ Visualisation settings applied: style={style}, palette={palette}")
113
+
114
+
115
+ def get_color_palette(name: str = "husl", n_colors: int = 8) -> list:
116
+ """
117
+ Get colour palette
118
+
119
+ Parameters:
120
+ -----------
121
+ name : str
122
+ Palette name
123
+ n_colors : int
124
+ Number of colours
125
+
126
+ Returns:
127
+ --------
128
+ list
129
+ List of colours in HEX format
130
+ """
131
+ palette_map = {
132
+ "husl": sns.color_palette("husl", n_colors),
133
+ "Set2": sns.color_palette("Set2", n_colors),
134
+ "Set3": sns.color_palette("Set3", n_colors),
135
+ "viridis": sns.color_palette("viridis", n_colors),
136
+ "plasma": sns.color_palette("plasma", n_colors),
137
+ "coolwarm": sns.color_palette("coolwarm", n_colors),
138
+ "RdYlBu": sns.color_palette("RdYlBu", n_colors),
139
+ "Spectral": sns.color_palette("Spectral", n_colors),
140
+ "tab10": sns.color_palette("tab10", n_colors),
141
+ "tab20": sns.color_palette("tab20", n_colors),
142
+ }
143
+
144
+ palette = palette_map.get(name, sns.color_palette("husl", n_colors))
145
+ return [f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
146
+ for r, g, b in palette]
147
+
148
+
149
+ # ============================================================================
150
+ # CONSTANTS
151
+ # ============================================================================
152
+
153
+ # Data types
154
+ DATETIME_FORMATS = [
155
+ "%Y-%m-%d", "%Y/%m/%d", "%d.%m.%Y", "%d/%m/%Y",
156
+ "%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S",
157
+ "%d.%m.%Y %H:%M:%S", "%d/%m/%Y %H:%M:%S"
158
+ ]
159
+
160
+ # Metrics
161
+ METRICS = {
162
+ "regression": ["mse", "rmse", "mae", "mape", "r2", "explained_variance"],
163
+ "classification": ["accuracy", "precision", "recall", "f1", "roc_auc"]
164
+ }
165
+
166
+ # Statistical constants
167
+ STATS_CONSTANTS = {
168
+ "confidence_levels": [0.9, 0.95, 0.99],
169
+ "z_scores": {0.9: 1.645, 0.95: 1.96, 0.99: 2.576},
170
+ "outlier_multipliers": {"mild": 1.5, "extreme": 3.0}
171
+ }
172
+
173
+ # Time series parameters
174
+ TIME_SERIES_CONSTANTS = {
175
+ "frequencies": {
176
+ "H": "hourly",
177
+ "D": "daily",
178
+ "W": "weekly",
179
+ "M": "monthly",
180
+ "Q": "quarterly",
181
+ "Y": "yearly"
182
+ },
183
+ "seasonal_periods": {
184
+ "hourly": 24,
185
+ "daily": 7,
186
+ "weekly": 52,
187
+ "monthly": 12,
188
+ "quarterly": 4,
189
+ "yearly": 1
190
+ }
191
+ }
192
+
193
+ # ============================================================================
194
+ # CONFIGURATION UTILITIES
195
+ # ============================================================================
196
+
197
+ def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
198
+ """
199
+ Load configuration from file
200
+
201
+ Parameters:
202
+ -----------
203
+ config_path : str, optional
204
+ Path to configuration file
205
+
206
+ Returns:
207
+ --------
208
+ Dict[str, Any]
209
+ Configuration dictionary
210
+ """
211
+ if config_path is None:
212
+ config_path = CONFIGS_DIR / "default_config.json"
213
+
214
+ config_path = Path(config_path)
215
+
216
+ if not config_path.exists():
217
+ print(f"⚠ Configuration file not found: {config_path}")
218
+ return {}
219
+
220
+ # Determine file format
221
+ if config_path.suffix.lower() in ['.json']:
222
+ with open(config_path, 'r', encoding='utf-8') as f:
223
+ config = json.load(f)
224
+ elif config_path.suffix.lower() in ['.yaml', '.yml']:
225
+ with open(config_path, 'r', encoding='utf-8') as f:
226
+ config = yaml.safe_load(f)
227
+ else:
228
+ raise ValueError(f"Unsupported file format: {config_path.suffix}")
229
+
230
+ print(f"✓ Configuration loaded from: {config_path}")
231
+ return config
232
+
233
+
234
+ def save_config(config: Dict[str, Any], config_path: str) -> None:
235
+ """
236
+ Save configuration to file
237
+
238
+ Parameters:
239
+ -----------
240
+ config : Dict[str, Any]
241
+ Configuration to save
242
+ config_path : str
243
+ Save path
244
+ """
245
+ config_path = Path(config_path)
246
+ config_path.parent.mkdir(parents=True, exist_ok=True)
247
+
248
+ # Determine format
249
+ if config_path.suffix.lower() in ['.json']:
250
+ with open(config_path, 'w', encoding='utf-8') as f:
251
+ json.dump(config, f, indent=2, ensure_ascii=False)
252
+ elif config_path.suffix.lower() in ['.yaml', '.yml']:
253
+ with open(config_path, 'w', encoding='utf-8') as f:
254
+ yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
255
+ else:
256
+ raise ValueError(f"Unsupported file format: {config_path.suffix}")
257
+
258
+ print(f"✓ Configuration saved to: {config_path}")
259
+
260
+
261
+ def merge_configs(base_config: Dict[str, Any],
262
+ override_config: Dict[str, Any]) -> Dict[str, Any]:
263
+ """
264
+ Recursive configuration merging
265
+
266
+ Parameters:
267
+ -----------
268
+ base_config : Dict[str, Any]
269
+ Base configuration
270
+ override_config : Dict[str, Any]
271
+ Override configuration
272
+
273
+ Returns:
274
+ --------
275
+ Dict[str, Any]
276
+ Merged configuration
277
+ """
278
+ result = base_config.copy()
279
+
280
+ for key, value in override_config.items():
281
+ if (key in result and isinstance(result[key], dict)
282
+ and isinstance(value, dict)):
283
+ result[key] = merge_configs(result[key], value)
284
+ else:
285
+ result[key] = value
286
+
287
+ return result
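+ 
+ 
+ # Usage sketch (illustrative): override one nested key, keep the rest of the base config
+ #   merged = merge_configs(load_config(), {"streamlit_settings": {"theme": "dark"}})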
288
+
289
+
290
+ # ============================================================================
291
+ # ENVIRONMENT SETUP
292
+ # ============================================================================
293
+
294
+ def setup_environment(
295
+ log_level: str = "INFO",
296
+ random_seed: int = 42,
297
+ enable_warnings: bool = False,
298
+ memory_limit_gb: Optional[int] = None
299
+ ) -> None:
300
+ """
301
+ Set up environment for reproducibility
302
+
303
+ Parameters:
304
+ -----------
305
+ log_level : str
306
+ Logging level
307
+ random_seed : int
308
+ Seed for random generators
309
+ enable_warnings : bool
310
+ Enable warnings
311
+ memory_limit_gb : int, optional
312
+ Memory limit in GB
313
+ """
314
+     import numpy as np
+     import random
+ 
+     # Set seeds
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+ 
+     # Optional frameworks are imported lazily so a missing package cannot break setup
+     try:
+         import torch
+         torch.manual_seed(random_seed)
+     except ImportError:
+         pass
+ 
+     try:
+         import tensorflow as tf
+         tf.random.set_seed(random_seed)
+     except ImportError:
+         pass
332
+
333
+ # Configure warnings
334
+ if enable_warnings:
335
+ warnings.filterwarnings('default')
336
+ else:
337
+ warnings.filterwarnings('ignore')
338
+
339
+ # Memory limit (if specified)
340
+ if memory_limit_gb:
341
+ import resource
342
+ soft, hard = resource.getrlimit(resource.RLIMIT_AS)
343
+ memory_limit = memory_limit_gb * 1024**3 # GB to bytes
344
+ resource.setrlimit(resource.RLIMIT_AS, (memory_limit, hard))
345
+ print(f"✓ Memory limit set: {memory_limit_gb} GB")
346
+
347
+ print(f"✓ Environment configured. Random seed: {random_seed}")
348
+
349
+
350
+ # ============================================================================
351
+ # AUTOMATIC SETUP ON IMPORT
352
+ # ============================================================================
353
+
354
+ # Automatically apply visualisation settings
355
+ setup_visualization()
356
+
357
+ # Export useful variables
358
+ __all__ = [
359
+ 'setup_visualization',
360
+ 'get_color_palette',
361
+ 'load_config',
362
+ 'save_config',
363
+ 'merge_configs',
364
+ 'setup_environment',
365
+ 'PROJECT_ROOT',
366
+ 'DATA_DIR',
367
+ 'RAW_DATA_DIR',
368
+ 'PROCESSED_DATA_DIR',
369
+ 'RESULTS_DIR',
370
+ 'PLOTS_DIR',
371
+ 'DATETIME_FORMATS',
372
+ 'METRICS',
373
+ 'STATS_CONSTANTS',
374
+ 'TIME_SERIES_CONSTANTS'
375
+ ]
correlations/__init__.py ADDED
File without changes
correlations/correlation_analyzer.py ADDED
@@ -0,0 +1,687 @@
1
+ # ============================================
2
+ # CLASS 8: CORRELATION AND MULTICOLLINEARITY ANALYSIS
3
+ # ============================================
4
+ import os
5
+ import traceback
6
+ from typing import Any, Dict, List, Optional
7
+ import logging
+ 
+ logger = logging.getLogger(__name__)
8
+
9
+ from config.config import Config
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ class CorrelationAnalyzer:
14
+ """Class for comprehensive correlation and multicollinearity analysis"""
15
+
16
+ def __init__(self, config: Config):
17
+ """
18
+ Initialise the analyser
19
+
20
+ Parameters:
21
+ -----------
22
+ config : Config
23
+ Experiment configuration
24
+ """
25
+ self.config = config
26
+ self.correlation_matrices = {}
27
+ self.high_correlation_pairs = {}
28
+ self.multicollinearity_info = {}
29
+ self.vif_scores = {}
30
+
31
+ def analyze(
32
+ self,
33
+ data: pd.DataFrame,
34
+ target_col: Optional[str] = None,
35
+ threshold: float = 0.8,
36
+ detailed: bool = True,
37
+ **kwargs
38
+ ) -> pd.DataFrame:
39
+ """
40
+ Analyse correlations in the data
41
+
42
+ Parameters:
43
+ -----------
44
+ data : pd.DataFrame
45
+ Input data
46
+ target_col : str, optional
47
+ Target variable
48
+ threshold : float
49
+ Threshold for identifying high correlations
50
+ detailed : bool
51
+ Whether to perform detailed analysis
52
+ **kwargs : dict
53
+ Additional parameters
54
+
55
+ Returns:
56
+ --------
57
+ pd.DataFrame
58
+ Correlation matrix
59
+ """
60
+ logger.info("\n" + "="*80)
61
+ logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS")
62
+ logger.info("="*80)
63
+
64
+ target_col = target_col or self.config.target_column
65
+
66
+ try:
67
+ # 1. Calculate correlation matrix
68
+ corr_matrix = self._compute_correlations(data, target_col)
69
+
70
+ if corr_matrix.empty:
71
+ logger.warning("Correlation matrix is empty")
72
+ return pd.DataFrame()
73
+
74
+ # 2. Identify high correlations
75
+ high_correlations = self._detect_high_correlations(corr_matrix, threshold)
76
+ self.high_correlation_pairs['pearson'] = high_correlations
77
+
78
+ # 3. Analyse correlations with target variable
79
+ target_correlations = []
80
+ if target_col in corr_matrix.columns:
81
+ target_correlations = self._get_target_correlations(corr_matrix, target_col)
82
+
83
+ # 4. Analyse multicollinearity (VIF)
84
+ vif_results = self._compute_vif_scores(data)
85
+
86
+ # 5. Detailed analysis if required
87
+ if detailed:
88
+ self._detailed_correlation_analysis(data, corr_matrix, target_col)
89
+
90
+ # 6. Visualisation
91
+ if self.config.save_plots:
92
+ self._plot_correlation_analysis(data, corr_matrix, target_col, high_correlations, vif_results)
93
+
94
+ # 7. Output results
95
+ self._log_analysis_results(corr_matrix, high_correlations, target_correlations, vif_results)
96
+
97
+ return corr_matrix
98
+
99
+ except Exception as e:
100
+ logger.error(f"Error in correlation analysis: {e}")
101
+ logger.error(traceback.format_exc())
102
+ return pd.DataFrame()
103
+
104
+ def _compute_correlations(
105
+ self,
106
+ data: pd.DataFrame,
107
+ target_col: str
108
+ ) -> pd.DataFrame:
109
+ """Calculate correlation matrix"""
110
+ logger.info("Calculating correlation matrix...")
111
+
112
+ # Select only numeric columns
113
+ numeric_data = data.select_dtypes(include=[np.number])
114
+
115
+ # Remove constant columns
116
+ numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
117
+
118
+ if numeric_data.shape[1] < 2:
119
+ logger.warning("Insufficient numeric features for analysis")
120
+ return pd.DataFrame()
121
+
122
+ # Remove missing values
123
+ numeric_data_clean = numeric_data.dropna()
124
+
125
+ if len(numeric_data_clean) < 10:
126
+ logger.warning("Insufficient data after cleaning")
127
+ return pd.DataFrame()
128
+
129
+ # Calculate Pearson correlation
130
+ try:
131
+ corr_matrix = numeric_data_clean.corr(method='pearson')
132
+ self.correlation_matrices['pearson'] = corr_matrix
133
+ logger.info(f"✓ Correlation matrix calculated: {corr_matrix.shape}")
134
+ return corr_matrix
135
+ except Exception as e:
136
+ logger.error(f"Error calculating correlation: {e}")
137
+ return pd.DataFrame()
138
+
139
+ def _detect_high_correlations(
140
+ self,
141
+ corr_matrix: pd.DataFrame,
142
+ threshold: float = 0.8
143
+ ) -> List[Dict[str, Any]]:
144
+ """Detect high correlations"""
145
+ high_correlations = []
146
+
147
+ if corr_matrix.empty:
148
+ return high_correlations
149
+
150
+ # Use upper triangle of matrix
151
+ upper_triangle = corr_matrix.where(
152
+ np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
153
+ )
154
+
155
+ # Find pairs with correlation above threshold
156
+ for col in upper_triangle.columns:
157
+ if col in upper_triangle:
158
+ high_corr_series = upper_triangle[col][abs(upper_triangle[col]) > threshold]
159
+
160
+ for row_idx, correlation in high_corr_series.items():
161
+ if not pd.isna(correlation):
162
+ high_correlations.append({
163
+ 'feature1': row_idx,
164
+ 'feature2': col,
165
+ 'correlation': float(correlation),
166
+ 'abs_correlation': abs(float(correlation))
167
+ })
168
+
169
+ # Sort by absolute correlation value
170
+ high_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
171
+
172
+ logger.info(f"High correlations detected (> {threshold}): {len(high_correlations)}")
173
+ return high_correlations
174
+
175
+ def _get_target_correlations(
176
+ self,
177
+ corr_matrix: pd.DataFrame,
178
+ target_col: str
179
+ ) -> List[Dict[str, Any]]:
180
+ """Get correlations with target variable"""
181
+ target_correlations = []
182
+
183
+ if target_col not in corr_matrix.columns:
184
+ return target_correlations
185
+
186
+ # Extract correlations with target variable
187
+ target_corr_series = corr_matrix[target_col]
188
+
189
+ for feature, correlation in target_corr_series.items():
190
+ if feature != target_col and not pd.isna(correlation):
191
+ target_correlations.append({
192
+ 'feature': feature,
193
+ 'correlation': float(correlation),
194
+ 'abs_correlation': abs(float(correlation)),
195
+ 'direction': 'positive' if correlation > 0 else 'negative'
196
+ })
197
+
198
+ # Sort by absolute value
199
+ target_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
200
+
201
+ logger.info(f"Correlations with target variable calculated: {len(target_correlations)}")
202
+ return target_correlations
203
+
204
+ def _compute_vif_scores(self, data: pd.DataFrame) -> Dict[str, Any]:
205
+ """Calculate VIF (Variance Inflation Factor)"""
206
+ logger.info("Analysing multicollinearity (VIF)...")
207
+
208
+ vif_results = {
209
+ 'scores': {},
210
+ 'issues': [],
211
+ 'summary': {
212
+ 'critical': 0,
213
+ 'high': 0,
214
+ 'medium': 0,
215
+ 'low': 0
216
+ }
217
+ }
218
+
219
+ try:
220
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
221
+ import statsmodels.api as sm
222
+
223
+ # Prepare data
224
+ numeric_data = data.select_dtypes(include=[np.number])
225
+ numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
226
+
227
+ # Remove missing and infinite values
228
+ clean_data = numeric_data.replace([np.inf, -np.inf], np.nan).dropna()
229
+
230
+ if clean_data.shape[0] < 10 or clean_data.shape[1] < 2:
231
+ logger.warning("Insufficient data for VIF analysis")
232
+ return vif_results
233
+
234
+ # Add constant
235
+ X = sm.add_constant(clean_data, has_constant='add')
236
+
237
+ # Calculate VIF for each feature
238
+ vif_scores = {}
239
+ for i, column in enumerate(X.columns):
240
+ if column == 'const':
241
+ continue
242
+
243
+ try:
244
+ vif = variance_inflation_factor(X.values, i)
245
+
246
+ # Handle extreme values
247
+ if np.isinf(vif) or vif > 1e6:
248
+ vif = 1e6
249
+
250
+ vif_scores[column] = float(vif)
251
+
252
+ # Classify by severity
253
+ if vif > 100:
254
+ vif_results['summary']['critical'] += 1
255
+ vif_results['issues'].append({
256
+ 'feature': column,
257
+ 'vif': float(vif),
258
+ 'severity': 'critical',
259
+ 'recommendation': 'Remove feature'
260
+ })
261
+ elif vif > 10:
262
+ vif_results['summary']['high'] += 1
263
+ vif_results['issues'].append({
264
+ 'feature': column,
265
+ 'vif': float(vif),
266
+ 'severity': 'high',
267
+ 'recommendation': 'Consider removal'
268
+ })
269
+ elif vif > 5:
270
+ vif_results['summary']['medium'] += 1
271
+ else:
272
+ vif_results['summary']['low'] += 1
273
+
274
+ except Exception as e:
275
+ logger.warning(f"VIF error for {column}: {e}")
276
+ vif_scores[column] = np.nan
277
+
278
+ vif_results['scores'] = vif_scores
279
+ self.vif_scores = vif_scores
280
+
281
+ logger.info(f"✓ VIF analysis completed. Critical features: {vif_results['summary']['critical']}")
282
+
283
+ except ImportError:
284
+ logger.warning("statsmodels not installed, skipping VIF analysis")
285
+ except Exception as e:
286
+ logger.error(f"VIF analysis error: {e}")
287
+
288
+ return vif_results
289
+
290
+ def _detailed_correlation_analysis(
291
+ self,
292
+ data: pd.DataFrame,
293
+ corr_matrix: pd.DataFrame,
294
+ target_col: str
295
+ ) -> None:
296
+ """Detailed correlation analysis"""
297
+ # Analyse correlation clusters
298
+ if not corr_matrix.empty and corr_matrix.shape[0] > 3:
299
+ try:
300
+ # Use clustering to group correlated features
301
+ from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
302
+ from scipy.spatial.distance import squareform
303
+
304
+ # Convert correlations to distances
305
+ distance_matrix = 1 - abs(corr_matrix)
306
+ np.fill_diagonal(distance_matrix.values, 0)
307
+
308
+ # Clustering
309
+ condensed_dist = squareform(distance_matrix)
310
+ Z = linkage(condensed_dist, method='average')
311
+
312
+ # Determine clusters
313
+ clusters = fcluster(Z, t=0.5, criterion='distance')
314
+
315
+ # Group features by cluster
316
+ feature_clusters = {}
317
+ for idx, cluster_id in enumerate(clusters):
318
+ feature = corr_matrix.columns[idx]
319
+ if cluster_id not in feature_clusters:
320
+ feature_clusters[cluster_id] = []
321
+ feature_clusters[cluster_id].append(feature)
322
+
323
+ # Save cluster information
324
+ self.multicollinearity_info['correlation_clusters'] = feature_clusters
325
+ logger.info(f"Correlated feature clusters detected: {len(feature_clusters)}")
326
+
327
+ except Exception as e:
328
+ logger.debug(f"Cluster analysis failed: {e}")
329
+
330
+ def _plot_correlation_analysis(
331
+ self,
332
+ data: pd.DataFrame,
333
+ corr_matrix: pd.DataFrame,
334
+ target_col: str,
335
+ high_correlations: List[Dict[str, Any]],
336
+ vif_results: Dict[str, Any]
337
+ ) -> None:
338
+ """Visualise correlation analysis"""
339
+ try:
340
+ import matplotlib.pyplot as plt
341
+ import seaborn as sns
342
+ from matplotlib import rcParams
343
+
344
+ # Style settings
345
+ plt.style.use('seaborn-v0_8-darkgrid')
346
+ rcParams.update({
347
+ 'figure.figsize': (12, 8),
348
+ 'font.size': 10,
349
+ 'axes.titlesize': 14,
350
+ 'axes.labelsize': 12
351
+ })
352
+
353
+ # Create directory
354
+ plots_dir = os.path.join(self.config.results_dir, 'plots', 'correlations')
355
+ os.makedirs(plots_dir, exist_ok=True)
356
+
357
+ # 1. Correlation matrix heatmap
358
+ if not corr_matrix.empty and corr_matrix.shape[0] > 1:
359
+ fig, ax = plt.subplots(figsize=(14, 12))
360
+
361
+ mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
362
+ sns.heatmap(
363
+ corr_matrix,
364
+ mask=mask,
365
+ annot=True,
366
+ fmt='.2f',
367
+ cmap='coolwarm',
368
+ center=0,
369
+ square=True,
370
+ linewidths=0.5,
371
+ cbar_kws={"shrink": 0.8},
372
+ ax=ax
373
+ )
374
+ ax.set_title('Correlation Matrix (Pearson)', fontweight='bold')
375
+ plt.tight_layout()
376
+ plt.savefig(os.path.join(plots_dir, 'correlation_matrix.png'),
377
+ dpi=150, bbox_inches='tight')
378
+ plt.close()
379
+
380
+ # 2. Target variable correlations
381
+ if target_col in corr_matrix.columns:
382
+ target_corrs = corr_matrix[target_col].drop(target_col, errors='ignore')
383
+ if not target_corrs.empty:
384
+ fig, ax = plt.subplots(figsize=(10, 8))
385
+
386
+ top_corrs = target_corrs.abs().sort_values(ascending=True).tail(20)
387
+ colors = ['red' if target_corrs[feat] < 0 else 'blue'
388
+ for feat in top_corrs.index]
389
+
390
+ ax.barh(range(len(top_corrs)), top_corrs.values, color=colors)
391
+ ax.set_yticks(range(len(top_corrs)))
392
+ ax.set_yticklabels(top_corrs.index)
393
+ ax.set_xlabel('Absolute correlation')
394
+ ax.set_title(f'Top-20 correlations with {target_col}', fontweight='bold')
395
+ ax.grid(True, alpha=0.3, axis='x')
396
+
397
+ plt.tight_layout()
398
+ plt.savefig(os.path.join(plots_dir, 'target_correlations.png'),
399
+ dpi=150, bbox_inches='tight')
400
+ plt.close()
401
+
402
+ # 3. VIF scores plot
403
+ if vif_results['scores']:
404
+ valid_scores = {k: v for k, v in vif_results['scores'].items()
405
+ if not pd.isna(v)}
406
+ if valid_scores:
407
+ fig, ax = plt.subplots(figsize=(12, 8))
408
+
409
+ sorted_scores = dict(sorted(valid_scores.items(),
410
+ key=lambda x: x[1],
411
+ reverse=True)[:25])
412
+
413
+ colors = []
414
+ for vif in sorted_scores.values():
415
+ if vif > 100:
416
+ colors.append('red')
417
+ elif vif > 10:
418
+ colors.append('orange')
419
+ elif vif > 5:
420
+ colors.append('yellow')
421
+ else:
422
+ colors.append('green')
423
+
424
+ bars = ax.barh(list(sorted_scores.keys()),
425
+ list(sorted_scores.values()),
426
+ color=colors, edgecolor='black')
427
+
428
+ ax.set_xlabel('VIF Score')
429
+ ax.set_title('VIF Scores (multicollinearity)', fontweight='bold')
430
+ ax.axvline(x=5, color='yellow', linestyle='--', alpha=0.7)
431
+ ax.axvline(x=10, color='orange', linestyle='--', alpha=0.7)
432
+ ax.axvline(x=100, color='red', linestyle='--', alpha=0.7)
433
+ ax.grid(True, alpha=0.3, axis='x')
434
+
435
+ plt.tight_layout()
436
+ plt.savefig(os.path.join(plots_dir, 'vif_scores.png'),
437
+ dpi=150, bbox_inches='tight')
438
+ plt.close()
439
+
440
+ # 4. High correlations plot
441
+ if high_correlations:
442
+ fig, ax = plt.subplots(figsize=(12, 8))
443
+
444
+ # Limit number for display
445
+ display_corrs = high_correlations[:15]
446
+
447
+ # Create labels for feature pairs
448
+ labels = [f"{corr['feature1']} ↔ {corr['feature2']}"
449
+ for corr in display_corrs]
450
+ values = [corr['correlation'] for corr in display_corrs]
451
+ colors = ['red' if v < 0 else 'blue' for v in values]
452
+
453
+ y_pos = np.arange(len(display_corrs))
454
+ ax.barh(y_pos, values, color=colors)
455
+ ax.set_yticks(y_pos)
456
+ ax.set_yticklabels(labels, fontsize=9)
457
+ ax.invert_yaxis()
458
+ ax.set_xlabel('Correlation')
459
+ ax.set_title('High correlations (|r| > 0.8)', fontweight='bold')
460
+ ax.grid(True, alpha=0.3, axis='x')
461
+
462
+ plt.tight_layout()
463
+ plt.savefig(os.path.join(plots_dir, 'high_correlations.png'),
464
+ dpi=150, bbox_inches='tight')
465
+ plt.close()
466
+
467
+ logger.info(f"Visualisations saved to {plots_dir}")
468
+
469
+ except Exception as e:
470
+ logger.warning(f"Error creating visualisations: {e}")
471
+
472
+ def _log_analysis_results(
473
+ self,
474
+ corr_matrix: pd.DataFrame,
475
+ high_correlations: List[Dict[str, Any]],
476
+ target_correlations: List[Dict[str, Any]],
477
+ vif_results: Dict[str, Any]
478
+ ) -> None:
479
+ """Log analysis results"""
480
+ logger.info("\n" + "="*80)
481
+ logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS REPORT")
482
+ logger.info("="*80)
483
+
484
+ # General information
485
+ logger.info(f"\n📊 GENERAL INFORMATION:")
486
+ logger.info(f" Correlation matrix size: {corr_matrix.shape}")
487
+ logger.info(f" Total features: {len(corr_matrix.columns)}")
488
+
489
+ # High correlations
490
+ if high_correlations:
491
+ logger.info(f"\n⚠ HIGH CORRELATIONS (|r| > 0.8): {len(high_correlations)}")
492
+ logger.info(" " + "-" * 60)
493
+
494
+ for i, corr in enumerate(high_correlations[:10]):
495
+ sign = "🟥" if corr['correlation'] < 0 else "🟩"
496
+ logger.info(f" {i+1:2d}. {sign} {corr['feature1']:25s} ↔ {corr['feature2']:25s}: {corr['correlation']:7.4f}")
497
+
498
+ if len(high_correlations) > 10:
499
+ logger.info(f" ... and {len(high_correlations) - 10} more pairs")
500
+
501
+ # Target variable correlations
502
+ if target_correlations:
503
+ logger.info(f"\n🎯 CORRELATIONS WITH TARGET VARIABLE:")
504
+ logger.info(" " + "-" * 60)
505
+
506
+ for i, corr in enumerate(target_correlations[:10]):
507
+ direction = "↓" if corr['correlation'] < 0 else "↑"
508
+ logger.info(f" {i+1:2d}. {direction} {corr['feature']:35s}: {corr['correlation']:7.4f}")
509
+
510
+ # Multicollinearity analysis
511
+ if vif_results['scores']:
512
+ logger.info(f"\n📈 MULTICOLLINEARITY ANALYSIS (VIF):")
513
+ logger.info(" " + "-" * 60)
514
+ logger.info(f" Critical (VIF > 100): {vif_results['summary']['critical']}")
515
+ logger.info(f" High (10 < VIF ≤ 100): {vif_results['summary']['high']}")
516
+ logger.info(f" Medium (5 < VIF ≤ 10): {vif_results['summary']['medium']}")
517
+ logger.info(f" Low (VIF ≤ 5): {vif_results['summary']['low']}")
518
+
519
+ # Top problematic features
520
+ if vif_results['issues']:
521
+ logger.info(f"\n🔴 PROBLEMATIC FEATURES (VIF > 10):")
522
+ for issue in vif_results['issues'][:10]:
523
+ logger.info(f" • {issue['feature']:35s}: VIF = {issue['vif']:7.1f} ({issue['severity']})")
524
+
525
+ logger.info("\n" + "="*80)
526
+ logger.info("RECOMMENDATIONS:")
527
+ logger.info("="*80)
528
+
529
+ # Generate recommendations
530
+ recommendations = []
531
+
532
+ if len(high_correlations) > 20:
533
+ recommendations.append("1. Remove highly correlated features (correlation method)")
534
+
535
+ if vif_results['summary']['critical'] > 0:
536
+ recommendations.append("2. Remove features with critical VIF (>100)")
537
+
538
+ if vif_results['summary']['high'] > 5:
539
+ recommendations.append("3. Consider removing features with VIF > 10")
540
+
541
+ if not recommendations:
542
+ recommendations.append("1. Data in good condition, no serious issues detected")
543
+ recommendations.append("2. Proceed to modelling")
544
+
545
+ for i, rec in enumerate(recommendations, 1):
546
+ logger.info(f" {rec}")
547
+
548
+ logger.info("\n" + "="*80)
549
+
550
+ def remove_highly_correlated(
551
+ self,
552
+ data: pd.DataFrame,
553
+ threshold: float = 0.85,
554
+ method: str = 'variance',
555
+ keep_target: bool = True,
556
+ keep_features: List[str] = None
557
+ ) -> pd.DataFrame:
558
+ """
559
+ Remove highly correlated features
560
+
561
+ Parameters:
562
+ -----------
563
+ data : pd.DataFrame
564
+ Source data
565
+ threshold : float
566
+ Correlation threshold for removal
567
+ method : str
568
+ Feature selection method for removal: 'variance', 'random', 'importance'
569
+ keep_target : bool
570
+ Whether to keep target variable
571
+ keep_features : List[str], optional
572
+ Features to keep
573
+
574
+ Returns:
575
+ --------
576
+ pd.DataFrame
577
+ Data after removing highly correlated features
578
+ """
579
+ logger.info("\n" + "="*80)
580
+ logger.info("REMOVING HIGHLY CORRELATED FEATURES")
581
+ logger.info("="*80)
582
+
583
+ data_clean = data.copy()
584
+
585
+ if 'pearson' not in self.correlation_matrices:
586
+ logger.warning("Correlation matrix not calculated, run analyze() first")
587
+ return data_clean
588
+
589
+ corr_matrix = self.correlation_matrices['pearson']
590
+
591
+ # Features to keep
592
+ features_to_keep = set()
593
+
594
+ if keep_target and self.config.target_column in data_clean.columns:
595
+ features_to_keep.add(self.config.target_column)
596
+
597
+ if keep_features:
598
+ for feat in keep_features:
599
+ if feat in data_clean.columns:
600
+ features_to_keep.add(feat)
601
+
602
+ # Temporal features (usually important for time series)
603
+ temporal_patterns = ['year', 'month', 'day', 'week', 'quarter',
604
+ 'hour', 'minute', 'second', 'sin', 'cos']
605
+
606
+ for col in data_clean.columns:
607
+ if any(pattern in col.lower() for pattern in temporal_patterns):
608
+ features_to_keep.add(col)
609
+
610
+ # Find highly correlated pairs
611
+ upper_triangle = corr_matrix.where(
612
+ np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
613
+ )
614
+
615
+ # Collect highly correlated features
616
+ correlated_features = set()
617
+ for col in upper_triangle.columns:
618
+ if col in features_to_keep:
619
+ continue
620
+
621
+ high_corr = upper_triangle[col][abs(upper_triangle[col]) > threshold]
622
+ for row_idx, corr_value in high_corr.items():
623
+ if not pd.isna(corr_value) and row_idx not in features_to_keep:
624
+ # Select which feature to remove
625
+ if method == 'variance':
626
+ # Remove the one with lower variance
627
+ var_col = data_clean[col].var()
628
+ var_row = data_clean[row_idx].var()
629
+ feature_to_remove = col if var_col < var_row else row_idx
630
+ elif method == 'importance':
631
+ # Remove the one with lower correlation to target variable
632
+ if self.config.target_column in corr_matrix.columns:
633
+ corr_col_target = abs(corr_matrix.loc[col, self.config.target_column])
634
+ corr_row_target = abs(corr_matrix.loc[row_idx, self.config.target_column])
635
+ feature_to_remove = col if corr_col_target < corr_row_target else row_idx
636
+ else:
637
+ # If no target, remove randomly
638
+ feature_to_remove = np.random.choice([col, row_idx])
639
+ else:
640
+ # Remove randomly
641
+ feature_to_remove = np.random.choice([col, row_idx])
642
+
643
+ correlated_features.add(feature_to_remove)
644
+
645
+ # Remove features
646
+ features_to_remove = list(correlated_features)
647
+
648
+ if features_to_remove:
649
+ data_clean = data_clean.drop(columns=features_to_remove)
650
+
651
+ logger.info(f"\n📊 REMOVAL RESULTS:")
652
+ logger.info(f" Initial feature count: {len(data.columns)}")
653
+ logger.info(f" Features removed: {len(features_to_remove)}")
654
+ logger.info(f" Final feature count: {len(data_clean.columns)}")
655
+ logger.info(f" Retained: {len(data_clean.columns)/len(data.columns)*100:.1f}%")
656
+
657
+ if features_to_remove:
658
+ logger.info(f"\n🗑️ REMOVED FEATURES:")
659
+ for i, feat in enumerate(sorted(features_to_remove)[:20]):
660
+ logger.info(f" {i+1:2d}. {feat}")
661
+ if len(features_to_remove) > 20:
662
+ logger.info(f" ... and {len(features_to_remove) - 20} more features")
663
+ else:
664
+ logger.info("✓ No highly correlated features detected, all features retained")
665
+
666
+ logger.info("="*80)
667
+ return data_clean
668
+
669
+ def get_report(self) -> Dict[str, Any]:
670
+ """Get analysis report"""
671
+ report = {
672
+ "correlation_matrix_shape": None,
673
+ "high_correlation_count": 0,
674
+ "vif_summary": {},
675
+ "target_correlation_count": 0
676
+ }
677
+
678
+ if 'pearson' in self.correlation_matrices:
679
+ report["correlation_matrix_shape"] = self.correlation_matrices['pearson'].shape
680
+
681
+ if 'pearson' in self.high_correlation_pairs:
682
+ report["high_correlation_count"] = len(self.high_correlation_pairs['pearson'])
683
+
684
+ if self.vif_scores:
685
+ report["vif_summary"] = self.vif_scores.get('summary', {})
686
+
687
+ return report
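For orientation, a minimal usage sketch of the analyzer defined above. The class name `CorrelationAnalyzer` and the no-argument `Config()` constructor are assumptions (only the methods appear in this hunk); `analyze()` is the entry point the code itself says must run before `remove_highly_correlated()`.

```python
import pandas as pd
from config.config import Config
# from <module> import CorrelationAnalyzer  # module path not shown in this hunk

config = Config()                       # assumes target_column / results_dir are configured
data = pd.read_csv(config.data_path)    # data_path is read the same way in DataLoader below

analyzer = CorrelationAnalyzer(config)  # assumed class name
analyzer.analyze(data)                  # must run first: it fills self.correlation_matrices

# Drop one feature from every pair with |r| > 0.85, keeping the target column
# and, per method='variance', preferring to keep the higher-variance feature.
data_clean = analyzer.remove_highly_correlated(data, threshold=0.85, method='variance')
print(analyzer.get_report())
```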
data_loader/__init__.py ADDED
File without changes
data_loader/data_loader.py ADDED
@@ -0,0 +1,487 @@
1
+ # ============================================
2
+ # CLASS 2: DATA LOADER
3
+ # ============================================
4
+ from datetime import datetime
5
+ import hashlib
6
+ import json
7
+ import traceback
8
+ from typing import Dict, List, Optional
9
+ import logging
+ logger = logging.getLogger(__name__)
10
+ from config.config import Config, DataType
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ class DataLoader:
15
+ """Class for loading and initial data processing"""
16
+
17
+ def __init__(self, config: Config):
18
+ """
19
+ Initialise data loader
20
+
21
+ Parameters:
22
+ -----------
23
+ config : Config
24
+ Experiment configuration
25
+ """
26
+ self.config = config
27
+ self.data = None
28
+ self.metadata = {}
29
+ self.data_hash = None
30
+ self.loading_time = None
31
+ self.data_types = {}
32
+ self.original_shape = None
33
+
34
+ def load_from_csv(
35
+ self,
36
+ data_path: Optional[str] = None,
37
+ parse_dates: List[str] = None,
38
+ date_format: str = None,
39
+ dtype: Dict = None,
40
+ **kwargs
41
+ ) -> pd.DataFrame:
42
+ """
43
+ Load data from CSV file
44
+
45
+ Parameters:
46
+ -----------
47
+ data_path : str, optional
48
+ Path to CSV file. If None, uses path from configuration.
49
+ parse_dates : List[str], optional
50
+ List of columns to parse as dates
51
+ date_format : str, optional
52
+ Date format
53
+ dtype : Dict, optional
54
+ Data types for columns
55
+ **kwargs : dict
56
+ Additional parameters for pd.read_csv
57
+
58
+ Returns:
59
+ --------
60
+ pd.DataFrame
61
+ Loaded data
62
+ """
63
+ logger.info("="*80)
64
+ logger.info("LOADING DATA FROM CSV")
65
+ logger.info("="*80)
66
+
67
+ start_time = datetime.now()
68
+
69
+ try:
70
+ path = data_path or self.config.data_path
71
+
72
+ if parse_dates is None:
73
+ parse_dates = ['date']
74
+
75
+ # Load data
76
+ self.data = pd.read_csv(
77
+ path,
78
+ parse_dates=parse_dates,
79
+ dayfirst=False,
80
+ dtype=dtype,
81
+ **kwargs
82
+ )
83
+
84
+ # Convert dates if needed
85
+ for date_col in parse_dates:
86
+ if date_col in self.data.columns:
87
+ if date_format:
88
+ self.data[date_col] = pd.to_datetime(
89
+ self.data[date_col],
90
+ format=date_format,
91
+ errors='coerce'
92
+ )
93
+ else:
94
+ self.data[date_col] = pd.to_datetime(
95
+ self.data[date_col],
96
+ errors='coerce'
97
+ )
98
+
99
+ # Save original shape
100
+ self.original_shape = self.data.shape
101
+
102
+ # Filter by years
103
+ if 'date' in self.data.columns:
104
+ mask = (self.data['date'].dt.year >= self.config.start_year) & \
105
+ (self.data['date'].dt.year <= self.config.end_year)
106
+ self.data = self.data.loc[mask].copy()
107
+
108
+ # Sort by date
109
+ if 'date' in self.data.columns:
110
+ self.data = self.data.sort_values('date').reset_index(drop=True)
111
+ # Set date as index
112
+ self.data.set_index('date', inplace=True)
113
+
114
+ # Calculate data hash
115
+ self.data_hash = self._calculate_data_hash()
116
+
117
+ # Analyse data types
118
+ self._analyse_data_types()
119
+
120
+ # Save metadata
121
+ self._save_metadata()
122
+
123
+ # Loading time
124
+ self.loading_time = (datetime.now() - start_time).total_seconds()
125
+
126
+ logger.info(f"✓ Loaded {len(self.data)} records, {len(self.data.columns)} columns")
127
+ logger.info(f" Period: {self.data.index.min()} - {self.data.index.max()}")
128
+ logger.info(f" Data types: {self.data_types}")
129
+ logger.info(f" Target variable: {self.config.target_column}")
130
+ logger.info(f" Loading time: {self.loading_time:.2f} sec")
131
+
132
+ return self.data
133
+
134
+ except Exception as e:
135
+ logger.error(f"✗ Error loading data: {e}")
136
+ logger.error(traceback.format_exc())
137
+ raise
138
+
139
+ def create_synthetic_data(
140
+ self,
141
+ n_days: int = 365*21,
142
+ trend_strength: float = 0.01,
143
+ seasonal_amplitude: List[float] = None,
144
+ noise_std: float = 10,
145
+ include_exogenous: bool = True,
146
+ random_state: int = 42
147
+ ) -> pd.DataFrame:
148
+ """
149
+ Create synthetic data for testing
150
+
151
+ Parameters:
152
+ -----------
153
+ n_days : int
154
+ Number of days to generate
155
+ trend_strength : float
156
+ Trend strength
157
+ seasonal_amplitude : List[float], optional
158
+ Seasonal component amplitudes
159
+ noise_std : float
160
+ Noise standard deviation
161
+ include_exogenous : bool
162
+ Whether to include exogenous variables
163
+ random_state : int
164
+ Seed for reproducibility
165
+
166
+ Returns:
167
+ --------
168
+ pd.DataFrame
169
+ Synthetic data
170
+ """
171
+ logger.info("="*80)
172
+ logger.info("CREATING SYNTHETIC DATA")
173
+ logger.info("="*80)
174
+
175
+ if seasonal_amplitude is None:
176
+ seasonal_amplitude = [50, 30, 20]
177
+
178
+ np.random.seed(random_state)
179
+
180
+ # Generate dates
181
+ dates = pd.date_range(
182
+ start=f'{self.config.start_year}-01-01',
183
+ periods=n_days,
184
+ freq='D'
185
+ )
186
+
187
+ t = np.arange(n_days)
188
+
189
+ # Base components
190
+ trend = trend_strength * t
191
+
192
+ # Seasonal components
193
+ seasonal = 0
194
+ periods = [365, 30, 7] # yearly, monthly, weekly seasonality
195
+ for i, (period, amplitude) in enumerate(zip(periods, seasonal_amplitude)):
196
+ seasonal += amplitude * np.sin(2 * np.pi * t / period)
197
+ if i < len(seasonal_amplitude) - 1:
198
+ seasonal += 0.5 * amplitude * np.cos(4 * np.pi * t / period)
199
+
200
+ # Cyclical component (business cycles)
201
+ cycle = 20 * np.sin(2 * np.pi * t / (365*5)) # 5-year cycle
202
+
203
+ # Noise
204
+ noise = np.random.normal(0, noise_std, n_days)
205
+
206
+ # Generate target variable
207
+ raskhodvoda = 100 + trend + seasonal + cycle + noise
208
+
209
+ # Create DataFrame
210
+ self.data = pd.DataFrame(
211
+ index=dates,
212
+ data={'raskhodvoda': raskhodvoda}
213
+ )
214
+
215
+ # Generate exogenous variables
216
+ if include_exogenous:
217
+ # Temperature with seasonality
218
+ tavg = 10 + 8 * np.sin(2 * np.pi * t / 365) + np.random.normal(0, 3, n_days)
219
+ tmin = tavg - 5 + np.random.normal(0, 2, n_days)
220
+ tmax = tavg + 5 + np.random.normal(0, 2, n_days)
221
+
222
+ # Water level with trend and seasonality
223
+ urovenvoda = 200 + 0.5 * t + 20 * np.sin(2 * np.pi * t / 365) + np.random.normal(0, 5, n_days)
224
+
225
+ # Add to DataFrame
226
+ self.data['tavg'] = tavg
227
+ self.data['tmin'] = tmin
228
+ self.data['tmax'] = tmax
229
+ self.data['urovenvoda'] = urovenvoda
230
+
231
+ # Add noisy lags
232
+ for lag in [1, 7, 30]:
233
+ self.data[f'tavg_lag_{lag}'] = self.data['tavg'].shift(lag) + np.random.normal(0, 1, n_days)
234
+
235
+ # Add missing values and outliers for testing
236
+ if n_days > 100:
237
+ # Missing values (5% of data)
238
+ mask_missing = np.random.random(n_days) < 0.05
239
+ self.data.loc[mask_missing, 'tavg'] = np.nan
240
+
241
+ # Outliers (1% of data)
242
+ mask_outliers = np.random.random(n_days) < 0.01
243
+ self.data.loc[mask_outliers, 'raskhodvoda'] *= 2
244
+
245
+ # Save metadata
246
+ self.metadata.update({
247
+ 'is_synthetic': True,
248
+ 'synthetic_params': {
249
+ 'n_days': n_days,
250
+ 'trend_strength': trend_strength,
251
+ 'seasonal_amplitude': seasonal_amplitude,
252
+ 'noise_std': noise_std,
253
+ 'include_exogenous': include_exogenous,
254
+ 'random_state': random_state
255
+ }
256
+ })
257
+
258
+ logger.info(f"✓ Created {len(self.data)} synthetic records")
259
+ logger.info(f" Columns: {list(self.data.columns)}")
260
+
261
+ return self.data
262
+
263
+ def _calculate_data_hash(self) -> str:
264
+ """Calculate data hash for tracking changes"""
265
+ if self.data is None:
266
+ return None
267
+
268
+ # Hash a sample of the first 1000 rows to detect data changes
269
+ sample = self.data.head(1000).to_string().encode()
270
+ return hashlib.md5(sample).hexdigest()
271
+
272
+ def _analyse_data_types(self) -> None:
273
+ """Analyse data types in DataFrame"""
274
+ if self.data is None:
275
+ return
276
+
277
+ for col in self.data.columns:
278
+ dtype = str(self.data[col].dtype)
279
+
280
+ if 'datetime' in dtype:
281
+ self.data_types[col] = DataType.TEMPORAL.value
282
+ elif 'int' in dtype or 'float' in dtype:
283
+ self.data_types[col] = DataType.NUMERIC.value
284
+ elif 'object' in dtype or 'category' in dtype:
285
+ # Check if categorical
286
+ unique_ratio = self.data[col].nunique() / len(self.data)
287
+ if unique_ratio < 0.1: # Less than 10% unique values
288
+ self.data_types[col] = DataType.CATEGORICAL.value
289
+ else:
290
+ self.data_types[col] = DataType.TEXT.value
291
+ else:
292
+ self.data_types[col] = 'unknown'
293
+
294
+ def _save_metadata(self) -> None:
295
+ """Save data metadata"""
296
+ if self.data is None:
297
+ return
298
+
299
+ # Basic metadata
300
+ self.metadata.update({
301
+ 'original_shape': list(self.original_shape) if self.original_shape else [],
302
+ 'current_shape': list(self.data.shape),
303
+ 'columns': list(self.data.columns),
304
+ 'data_types': self.data_types,
305
+ 'date_range': {
306
+ 'min': self.data.index.min().strftime('%Y-%m-%d') if pd.notnull(self.data.index.min()) else None,
307
+ 'max': self.data.index.max().strftime('%Y-%m-%d') if pd.notnull(self.data.index.max()) else None
308
+ },
309
+ 'data_hash': self.data_hash,
310
+ 'loading_time': self.loading_time
311
+ })
312
+
313
+ # Statistics for numeric columns
314
+ numeric_cols = self.data.select_dtypes(include=[np.number]).columns
315
+ if len(numeric_cols) > 0:
316
+ stats = self.data[numeric_cols].describe().to_dict()
317
+ # Add additional statistics
318
+ for col in numeric_cols:
319
+ stats[col]['skewness'] = float(self.data[col].skew())
320
+ stats[col]['kurtosis'] = float(self.data[col].kurtosis())
321
+ stats[col]['cv'] = float(self.data[col].std() / self.data[col].mean()) if self.data[col].mean() != 0 else np.nan
322
+
323
+ self.metadata['numeric_statistics'] = stats
324
+
325
+ # Missing values information
326
+ missing_info = {
327
+ 'total_missing': int(self.data.isnull().sum().sum()),
328
+ 'missing_by_column': self.data.isnull().sum().to_dict(),
329
+ 'missing_percentage': (self.data.isnull().sum() / len(self.data) * 100).to_dict(),
330
+ 'rows_with_missing': int(self.data.isnull().any(axis=1).sum()),
331
+ 'columns_with_missing': self.data.columns[self.data.isnull().any()].tolist()
332
+ }
333
+ self.metadata['missing_info'] = missing_info
334
+
335
+ def get_data_info(self) -> Dict:
336
+ """Get information about data"""
337
+ if self.data is None:
338
+ return {}
339
+
340
+ info = {
341
+ 'shape': list(self.data.shape),
342
+ 'columns': list(self.data.columns),
343
+ 'data_types': self.data_types,
344
+ 'date_range': {
345
+ 'min': self.data.index.min().strftime('%Y-%m-%d') if pd.notnull(self.data.index.min()) else None,
346
+ 'max': self.data.index.max().strftime('%Y-%m-%d') if pd.notnull(self.data.index.max()) else None
347
+ },
348
+ 'target_column': self.config.target_column,
349
+ 'numeric_columns': self.data.select_dtypes(include=[np.number]).columns.tolist(),
350
+ 'categorical_columns': [col for col, dtype in self.data_types.items()
351
+ if dtype == DataType.CATEGORICAL.value],
352
+ 'missing_info': self.metadata.get('missing_info', {})
353
+ }
354
+
355
+ return info
356
+
357
+ def save_raw_data_info(self) -> None:
358
+ """Save raw data information"""
359
+ if self.data is None:
360
+ return
361
+
362
+ info_path = f'{self.config.results_dir}/reports/raw_data_info.json'
363
+
364
+ # Custom JSON encoder for handling numpy types
365
+ class NumpyEncoder(json.JSONEncoder):
366
+ def default(self, obj):
367
+ if isinstance(obj, (np.integer, np.floating)):
368
+ if np.isnan(obj):
369
+ return None
370
+ return float(obj)
371
+ elif isinstance(obj, np.bool_):
372
+ return bool(obj)
373
+ elif isinstance(obj, np.ndarray):
374
+ return obj.tolist()
375
+ elif isinstance(obj, pd.Timestamp):
376
+ return obj.strftime('%Y-%m-%d %H:%M:%S')
377
+ elif isinstance(obj, pd.Period):
378
+ return str(obj)
379
+ return super().default(obj)
380
+
381
+ with open(info_path, 'w', encoding='utf-8') as f:
382
+ json.dump(self.metadata, f, indent=4, ensure_ascii=False, cls=NumpyEncoder)
383
+
384
+ logger.info(f"✓ Raw data information saved: {info_path}")
385
+
386
+ def resample_data(
387
+ self,
388
+ freq: str = None,
389
+ method: str = 'mean'
390
+ ) -> pd.DataFrame:
391
+ """
392
+ Resample time series data
393
+
394
+ Parameters:
395
+ -----------
396
+ freq : str, optional
397
+ New frequency (e.g., 'D', 'W', 'M')
398
+ method : str
399
+ Aggregation method: 'mean', 'sum', 'last', 'first'
400
+
401
+ Returns:
402
+ --------
403
+ pd.DataFrame
404
+ Resampled data
405
+ """
406
+ if self.data is None:
407
+ logger.warning("Data not loaded")
408
+ return None
409
+
410
+ freq = freq or self.config.freq
411
+
412
+ # Check if index is datetime
413
+ if not isinstance(self.data.index, pd.DatetimeIndex):
414
+ logger.error("Data index is not DatetimeIndex")
415
+ return self.data
416
+
417
+ # Aggregation methods
418
+ agg_methods = {
419
+ 'mean': np.mean,
420
+ 'sum': np.sum,
421
+ 'last': lambda x: x.iloc[-1],
422
+ 'first': lambda x: x.iloc[0],
423
+ 'min': np.min,
424
+ 'max': np.max,
425
+ 'median': np.median
426
+ }
427
+
428
+ if method not in agg_methods:
429
+ logger.warning(f"Method {method} not supported, using mean")
430
+ method = 'mean'
431
+
432
+ # Resampling
433
+ try:
434
+ if method == 'last':
435
+ resampled_data = self.data.resample(freq).last()
436
+ elif method == 'first':
437
+ resampled_data = self.data.resample(freq).first()
438
+ else:
439
+ resampled_data = self.data.resample(freq).agg(agg_methods[method])
440
+
441
+ logger.info(f"Data resampled to frequency {freq}, method {method}")
442
+ logger.info(f"Size before: {len(self.data)}, after: {len(resampled_data)}")
443
+
444
+ self.data = resampled_data
445
+ return self.data
446
+
447
+ except Exception as e:
448
+ logger.error(f"Error during resampling: {e}")
449
+ return self.data
450
+
451
+ def detect_frequency(self) -> str:
452
+ """
453
+ Automatically detect data frequency
454
+
455
+ Returns:
456
+ --------
457
+ str
458
+ Detected data frequency
459
+ """
460
+ if self.data is None or len(self.data) < 2:
461
+ return 'unknown'
462
+
463
+ if not isinstance(self.data.index, pd.DatetimeIndex):
464
+ return 'irregular'
465
+
466
+ # Calculate differences between timestamps
467
+ diffs = pd.Series(self.data.index).diff().dropna()
468
+
469
+ if len(diffs) == 0:
470
+ return 'unknown'
471
+
472
+ # Most frequent difference
473
+ mode_diff = diffs.mode().iloc[0] if not diffs.mode().empty else diffs.iloc[0]
474
+
475
+ # Determine frequency
476
+ if mode_diff <= pd.Timedelta('1 hour'):
+ return 'H' # Hourly
+ elif mode_diff <= pd.Timedelta('1 day'):
+ return 'D' # Daily
+ elif mode_diff <= pd.Timedelta('7 days'):
+ return 'W' # Weekly
+ elif mode_diff <= pd.Timedelta('31 days'):
+ return 'M' # Monthly
+ elif mode_diff <= pd.Timedelta('92 days'):
+ return 'Q' # Quarterly
+ else:
+ return 'Y' # Yearly
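A short, hedged example of how this loader appears intended to be driven; the `Config` fields used here (`data_path`, `start_year`/`end_year`, `target_column`, `freq`) are assumed to carry sensible defaults.

```python
from config.config import Config
from data_loader.data_loader import DataLoader

config = Config()  # assumes data_path, start_year/end_year, target_column and freq are set
loader = DataLoader(config)

# Real data: df = loader.load_from_csv(parse_dates=['date'])
df = loader.create_synthetic_data(n_days=365 * 3, noise_std=5)

print(loader.detect_frequency())        # 'D' for the daily synthetic series
weekly = loader.resample_data(freq='W', method='mean')
print(loader.get_data_info()['shape'])  # row count shrinks ~7x after weekly resampling
```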
decomposition/__init__.py ADDED
File without changes
decomposition/decomposer.py ADDED
@@ -0,0 +1,690 @@
1
+ # ============================================
2
+ # CLASS 7: TIME SERIES DECOMPOSITION
3
+ # ============================================
4
+ import traceback
5
+ from typing import Dict, Optional
6
+ import logging
+ logger = logging.getLogger(__name__)
7
+
8
+ from config.config import Config
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import statsmodels.api as sm
14
+ from scipy import stats
15
+ from statsmodels.tsa.seasonal import seasonal_decompose, STL
16
+ from statsmodels.tsa.stattools import adfuller, kpss, acf, pacf
17
+ from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
18
+ from statsmodels.stats.diagnostic import acorr_ljungbox
19
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
20
+ from statsmodels.tsa.holtwinters import ExponentialSmoothing
21
+
22
+
23
+ class TimeSeriesDecomposer:
24
+ """Class for time series decomposition"""
25
+
26
+ def __init__(self, config: Config):
27
+ """
28
+ Initialise decomposer
29
+
30
+ Parameters:
31
+ -----------
32
+ config : Config
33
+ Experiment configuration
34
+ """
35
+ self.config = config
36
+ self.decomposition_results = {}
37
+ self.decomposition_models = {}
38
+ self.seasonal_periods = {}
39
+
40
+ def decompose(
41
+ self,
42
+ data: pd.DataFrame,
43
+ target_col: Optional[str] = None,
44
+ method: str = 'stl',
45
+ period: Optional[int] = None,
46
+ **kwargs
47
+ ) -> Dict:
48
+ """
49
+ Decompose time series into components
50
+
51
+ Parameters:
52
+ -----------
53
+ data : pd.DataFrame
54
+ Input data
55
+ target_col : str, optional
56
+ Target variable. If None, uses configuration value.
57
+ method : str
58
+ Decomposition model: 'stl', 'seasonal_decompose', 'mstl', 'naive'
59
+ period : int, optional
60
+ Seasonality period. If None, uses configuration value.
61
+ **kwargs : dict
62
+ Additional parameters for method
63
+
64
+ Returns:
65
+ --------
66
+ Dict
67
+ Decomposition results
68
+ """
69
+ logger.info("\n" + "="*80)
70
+ logger.info("TIME SERIES DECOMPOSITION")
71
+ logger.info("="*80)
72
+
73
+ target_col = target_col or self.config.target_column
74
+ period = period or self.config.seasonal_period
75
+
76
+ if target_col not in data.columns:
77
+ logger.error(f"Target variable '{target_col}' not found")
78
+ return {}
79
+
80
+ # Set date as index if not set
81
+ if not isinstance(data.index, pd.DatetimeIndex):
82
+ if 'date' in data.columns:
83
+ data = data.set_index('date')
84
+ else:
85
+ logger.error("DatetimeIndex required for decomposition")
86
+ return {}
87
+
88
+ series = data[target_col]
89
+
90
+ # Automatic seasonality period detection
91
+ if period is None or period == 'auto':
92
+ period = self._detect_seasonal_period(series)
93
+ logger.info(f"Automatically detected seasonality period: {period}")
94
+
95
+ try:
96
+ decomposition_result = None
97
+
98
+ if method == 'stl':
99
+ decomposition_result = self._stl_decomposition(series, period, **kwargs)
100
+ elif method == 'seasonal_decompose':
101
+ decomposition_result = self._seasonal_decompose(series, period, **kwargs)
102
+ elif method == 'mstl':
103
+ decomposition_result = self._mstl_decomposition(series, **kwargs)
104
+ elif method == 'naive':
105
+ decomposition_result = self._naive_decomposition(series, period, **kwargs)
106
+ else:
107
+ logger.warning(f"Method {method} not supported, using STL")
108
+ decomposition_result = self._stl_decomposition(series, period, **kwargs)
109
+
110
+ if decomposition_result is None:
111
+ logger.error("Decomposition failed")
112
+ return {}
113
+
114
+ # Analyse residuals
115
+ residuals_info = self._analyse_residuals(decomposition_result.get('residual', None))
116
+
117
+ # Analyse seasonality
118
+ seasonal_info = self._analyse_seasonality(
119
+ decomposition_result.get('seasonal', None),
120
+ period
121
+ )
122
+
123
+ # Save results
124
+ self.decomposition_results[target_col] = {
125
+ 'method': method,
126
+ 'period': period,
127
+ 'residuals_analysis': residuals_info,
128
+ 'seasonality_analysis': seasonal_info,
129
+ 'components_present': list(decomposition_result.keys()),
130
+ 'decomposition_stats': {
131
+ 'trend_strength': self._calculate_trend_strength(
132
+ decomposition_result.get('trend', None),
133
+ decomposition_result.get('residual', None)
134
+ ),
135
+ 'seasonal_strength': self._calculate_seasonal_strength(
136
+ decomposition_result.get('seasonal', None),
137
+ decomposition_result.get('residual', None)
138
+ )
139
+ }
140
+ }
141
+
142
+ # Visualisation
143
+ if self.config.save_plots:
144
+ self._plot_decomposition(data, target_col, decomposition_result, method, period)
145
+
146
+ # Additional visualisation
147
+ if residuals_info:
148
+ self._plot_residuals_analysis(decomposition_result.get('residual', None), target_col)
149
+
150
+ return self.decomposition_results[target_col]
151
+
152
+ except Exception as e:
153
+ logger.error(f"Error during decomposition: {e}")
154
+ logger.error(traceback.format_exc())
155
+ return {}
156
+
157
+ def _detect_seasonal_period(self, series: pd.Series) -> int:
158
+ """Automatic seasonality period detection"""
159
+ if len(series) < 100:
160
+ return self.config.seasonal_period
161
+
162
+ try:
163
+ # Use autocorrelation to determine period
164
+ acf_values = acf(series.dropna(), nlags=min(500, len(series)//2))
165
+
166
+ # Find peaks in autocorrelation
167
+ peaks = []
168
+ for i in range(1, len(acf_values)-1):
169
+ if acf_values[i] > acf_values[i-1] and acf_values[i] > acf_values[i+1]:
170
+ if acf_values[i] > 0.3: # Significance threshold
171
+ peaks.append(i)
172
+
173
+ if peaks:
174
+ # Take the first (shortest-lag) peak as the candidate period
175
+ dominant_period = peaks[0]
176
+
177
+ # Check for multiple periods
178
+ for period in [7, 30, 90, 365]:
179
+ if abs(dominant_period - period) <= 2:
180
+ return period
181
+
182
+ return dominant_period
183
+
184
+ return self.config.seasonal_period
185
+
186
+ except Exception:
187
+ return self.config.seasonal_period
188
+
189
+ def _stl_decomposition(
190
+ self,
191
+ series: pd.Series,
192
+ period: int,
193
+ **kwargs
194
+ ) -> Optional[Dict]:
195
+ """STL decomposition"""
196
+ try:
197
+ if len(series) < 2 * period:
198
+ logger.warning(f"Insufficient data for STL decomposition with period {period}")
199
+ return self._seasonal_decompose(series, period, **kwargs)
200
+
201
+ # STL decomposition
202
+ stl = STL(
203
+ series,
204
+ period=period,
205
+ seasonal=kwargs.get('seasonal', 7),
206
+ trend=kwargs.get('trend', None),
207
+ robust=kwargs.get('robust', True),
208
+ seasonal_deg=kwargs.get('seasonal_deg', 1),
209
+ trend_deg=kwargs.get('trend_deg', 1),
210
+ low_pass_deg=kwargs.get('low_pass_deg', 1)
211
+ )
212
+
213
+ result = stl.fit()
214
+
215
+ return {
216
+ 'trend': result.trend,
217
+ 'seasonal': result.seasonal,
218
+ 'residual': result.resid,
219
+ 'observed': series
220
+ }
221
+
222
+ except Exception as e:
223
+ logger.warning(f"STL decomposition failed: {e}")
224
+ return self._seasonal_decompose(series, period, **kwargs)
225
+
226
+ def _seasonal_decompose(
227
+ self,
228
+ series: pd.Series,
229
+ period: int,
230
+ **kwargs
231
+ ) -> Optional[Dict]:
232
+ """Seasonal decomposition from statsmodels"""
233
+ try:
234
+ model = kwargs.get('model', 'additive')
235
+
236
+ if len(series) < 2 * period:
237
+ # Reduce period if insufficient data
238
+ period = max(7, len(series) // 4)
239
+
240
+ decomposition = seasonal_decompose(
241
+ series,
242
+ model=model,
243
+ period=period,
244
+ extrapolate_trend=kwargs.get('extrapolate_trend', 'freq'),
245
+ two_sided=kwargs.get('two_sided', True)
246
+ )
247
+
248
+ return {
249
+ 'trend': decomposition.trend,
250
+ 'seasonal': decomposition.seasonal,
251
+ 'residual': decomposition.resid,
252
+ 'observed': series
253
+ }
254
+
255
+ except Exception as e:
256
+ logger.warning(f"Seasonal decompose failed: {e}")
257
+ return self._naive_decomposition(series, period, **kwargs)
258
+
259
+ def _mstl_decomposition(
260
+ self,
261
+ series: pd.Series,
262
+ **kwargs
263
+ ) -> Optional[Dict]:
264
+ """Multi-seasonal decomposition (simplified)"""
265
+ try:
266
+ # Simplified MSTL version
267
+ periods = kwargs.get('periods', [7, 365])
268
+
269
+ result = {
270
+ 'observed': series,
271
+ 'trend': None,
272
+ 'seasonal': pd.Series(0, index=series.index),
273
+ 'residual': series.copy()
274
+ }
275
+
276
+ # Sequentially remove seasonal components
277
+ for period in periods:
278
+ if len(series) >= 2 * period:
279
+ try:
280
+ decomp = seasonal_decompose(
281
+ result['residual'],
282
+ model='additive',
283
+ period=period,
284
+ extrapolate_trend='freq'
285
+ )
286
+
287
+ if result['trend'] is None:
288
+ result['trend'] = decomp.trend
289
+
290
+ result['seasonal'] = result['seasonal'] + decomp.seasonal
291
+ result['residual'] = decomp.resid
292
+ except Exception:
293
+ continue
294
+
295
+ if result['trend'] is None:
296
+ result['trend'] = series.rolling(window=min(365, len(series)//4), center=True).mean()
297
+
298
+ return result
299
+
300
+ except Exception as e:
301
+ logger.warning(f"MSTL decomposition failed: {e}")
302
+ return self._seasonal_decompose(series, 365, **kwargs)
303
+
304
+ def _naive_decomposition(
305
+ self,
306
+ series: pd.Series,
307
+ period: int,
308
+ **kwargs
309
+ ) -> Optional[Dict]:
310
+ """Naive decomposition"""
311
+ try:
312
+ # Simple decomposition using moving averages
313
+ trend = series.rolling(
314
+ window=min(period, len(series)//4),
315
+ center=True,
316
+ min_periods=1
317
+ ).mean()
318
+
319
+ # Seasonal component
320
+ if period > 1:
321
+ # Average by seasons
322
+ seasonal = series.groupby(series.index.dayofyear if period == 365 else
323
+ series.index.dayofweek if period == 7 else
324
+ series.index.month).transform('mean')
325
+ seasonal = seasonal - seasonal.mean()
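+ # centre the seasonal pattern at zero so the overall level stays in the trend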
326
+ else:
327
+ seasonal = pd.Series(0, index=series.index)
328
+
329
+ residual = series - trend - seasonal
330
+
331
+ return {
332
+ 'trend': trend,
333
+ 'seasonal': seasonal,
334
+ 'residual': residual,
335
+ 'observed': series
336
+ }
337
+
338
+ except Exception as e:
339
+ logger.error(f"Naive decomposition failed: {e}")
340
+ return None
341
+
342
+ def _analyse_residuals(self, residuals) -> Dict:
343
+ """Analyse decomposition residuals"""
344
+ if residuals is None:
345
+ return {}
346
+
347
+ residuals_clean = residuals.dropna()
348
+
349
+ if len(residuals_clean) == 0:
350
+ return {}
351
+
352
+ stats_info = {
353
+ 'mean': float(residuals_clean.mean()),
354
+ 'std': float(residuals_clean.std()),
355
+ 'skewness': float(residuals_clean.skew()),
356
+ 'kurtosis': float(residuals_clean.kurtosis()),
357
+ 'min': float(residuals_clean.min()),
358
+ 'max': float(residuals_clean.max()),
359
+ 'mad': float((residuals_clean - residuals_clean.mean()).abs().mean()),
360
+ 'normality_tests': {},
361
+ 'autocorrelation_tests': {}
362
+ }
363
+
364
+ # Normality test
365
+ if len(residuals_clean) > 3:
366
+ try:
367
+ # Shapiro-Wilk test
368
+ shapiro_stat, shapiro_p = stats.shapiro(residuals_clean.iloc[:5000])
369
+ stats_info['normality_tests']['shapiro_wilk'] = {
370
+ 'statistic': float(shapiro_stat),
371
+ 'pvalue': float(shapiro_p),
372
+ 'is_normal': shapiro_p > 0.05
373
+ }
374
+
375
+ # Anderson-Darling test
376
+ anderson_result = stats.anderson(residuals_clean, dist='norm')
377
+ stats_info['normality_tests']['anderson_darling'] = {
378
+ 'statistic': float(anderson_result.statistic),
379
+ 'critical_values': {str(level): float(value)
380
+ for level, value in zip(anderson_result.significance_level,
381
+ anderson_result.critical_values)},
382
+ 'is_normal': anderson_result.statistic < anderson_result.critical_values[2] # At 5% level
383
+ }
384
+ except Exception:
385
+ stats_info['normality_tests']['error'] = 'not enough data or calculation error'
386
+
387
+ # Autocorrelation test
388
+ try:
389
+ # Ljung-Box test
390
+ lb_test = acorr_ljungbox(residuals_clean, lags=[10, 20, 30], return_df=True)
391
+
392
+ autocorr_info = {}
393
+ for idx, row in lb_test.iterrows():
394
+ autocorr_info[f'lag_{int(row.name)}'] = {
395
+ 'statistic': float(row['lb_stat']),
396
+ 'pvalue': float(row['lb_pvalue']),
397
+ 'has_autocorrelation': row['lb_pvalue'] < 0.05
398
+ }
399
+
400
+ stats_info['autocorrelation_tests']['ljung_box'] = autocorr_info
401
+
402
+ # Durbin-Watson test
403
+ try:
404
+ dw_stat = sm.stats.stattools.durbin_watson(residuals_clean)
405
+ stats_info['autocorrelation_tests']['durbin_watson'] = {
406
+ 'statistic': float(dw_stat),
407
+ 'interpretation': 'no autocorrelation' if 1.5 < dw_stat < 2.5 else
408
+ 'positive autocorrelation' if dw_stat < 1.5 else
409
+ 'negative autocorrelation'
410
+ }
411
+ except Exception:
412
+ pass
413
+
414
+ except Exception:
415
+ stats_info['autocorrelation_tests']['error'] = 'calculation error'
416
+
417
+ # Heteroskedasticity test
418
+ try:
419
+ # ARCH test
420
+ from statsmodels.stats.diagnostic import het_arch
421
+ arch_test = het_arch(residuals_clean)
422
+ stats_info['heteroskedasticity_tests'] = {
423
+ 'arch': {
424
+ 'statistic': float(arch_test[0]),
425
+ 'pvalue': float(arch_test[1]),
426
+ 'is_homoskedastic': arch_test[1] > 0.05
427
+ }
428
+ }
429
+ except Exception:
430
+ pass
431
+
432
+ return stats_info
433
+
434
+ def _analyse_seasonality(self, seasonal_component, period: int) -> Dict:
435
+ """Analyse seasonal component"""
436
+ if seasonal_component is None:
437
+ return {}
438
+
439
+ seasonal_clean = seasonal_component.dropna()
440
+
441
+ if len(seasonal_clean) == 0:
442
+ return {}
443
+
444
+ analysis = {
445
+ 'period': period,
446
+ 'amplitude': float(seasonal_clean.max() - seasonal_clean.min()),
447
+ 'mean_amplitude': float(seasonal_clean.abs().mean()),
448
+ 'seasonal_strength': float(seasonal_clean.std()),
449
+ 'periodicity_check': {}
450
+ }
451
+
452
+ # Check periodicity via autocorrelation
453
+ if len(seasonal_clean) > period * 2:
454
+ try:
455
+ acf_values = acf(seasonal_clean, nlags=min(period * 3, len(seasonal_clean)//2))
456
+
457
+ # Look for peaks at expected lags
458
+ expected_lags = [period, period*2]
459
+ peaks_found = []
460
+
461
+ for lag in expected_lags:
462
+ if lag < len(acf_values):
463
+ if acf_values[lag] > 0.5: # Strong autocorrelation at period
464
+ peaks_found.append({
465
+ 'lag': lag,
466
+ 'autocorrelation': float(acf_values[lag]),
467
+ 'is_significant': True
468
+ })
469
+
470
+ analysis['periodicity_check']['autocorrelation_peaks'] = peaks_found
471
+ analysis['periodicity_check']['is_periodic'] = len(peaks_found) > 0
472
+ except Exception:
473
+ pass
474
+
475
+ # Seasonality pattern analysis
476
+ if isinstance(seasonal_clean.index, pd.DatetimeIndex):
477
+ try:
478
+ # Group by months/week days
479
+ if period == 12 or period == 365:
480
+ # Monthly seasonality
481
+ monthly_seasonal = seasonal_clean.groupby(seasonal_clean.index.month).mean()
482
+ analysis['monthly_pattern'] = monthly_seasonal.to_dict()
483
+
484
+ if period == 7 or period == 365:
485
+ # Daily seasonality
486
+ daily_seasonal = seasonal_clean.groupby(seasonal_clean.index.dayofweek).mean()
487
+ analysis['daily_pattern'] = daily_seasonal.to_dict()
488
+ except Exception:
489
+ pass
490
+
491
+ return analysis
492
+
493
+ def _calculate_trend_strength(self, trend, residual) -> float:
494
+ """Calculate trend strength"""
495
+ if trend is None or residual is None:
496
+ return 0.0
497
+
498
+ trend_clean = trend.dropna()
499
+ residual_clean = residual.dropna()
500
+
501
+ if len(trend_clean) == 0 or len(residual_clean) == 0:
502
+ return 0.0
503
+
504
+ # Trend strength = 1 - Var(residual) / Var(trend + residual)
505
+ try:
506
+ var_total = np.var(trend_clean + residual_clean)
507
+ if var_total > 0:
508
+ trend_strength = 1 - np.var(residual_clean) / var_total
509
+ return max(0.0, min(1.0, float(trend_strength)))
510
+ except Exception:
511
+ pass
512
+
513
+ return 0.0
514
+
515
+ def _calculate_seasonal_strength(self, seasonal, residual) -> float:
516
+ """Calculate seasonality strength"""
517
+ if seasonal is None or residual is None:
518
+ return 0.0
519
+
520
+ seasonal_clean = seasonal.dropna()
521
+ residual_clean = residual.dropna()
522
+
523
+ if len(seasonal_clean) == 0 or len(residual_clean) == 0:
524
+ return 0.0
525
+
526
+ # Seasonality strength = 1 - Var(residual) / Var(seasonal + residual)
527
+ try:
528
+ var_total = np.var(seasonal_clean + residual_clean)
529
+ if var_total > 0:
530
+ seasonal_strength = 1 - np.var(residual_clean) / var_total
531
+ return max(0.0, min(1.0, float(seasonal_strength)))
532
+ except Exception:
533
+ pass
534
+
535
+ return 0.0
536
+
537
+ def _plot_decomposition(
538
+ self,
539
+ data: pd.DataFrame,
540
+ target_col: str,
541
+ decomposition: Dict,
542
+ method: str,
543
+ period: int
544
+ ) -> None:
545
+ """Visualise decomposition"""
546
+ fig, axes = plt.subplots(4, 1, figsize=(14, 12))
547
+
548
+ # Original series
549
+ axes[0].plot(decomposition.get('observed', pd.Series()))
550
+ axes[0].set_ylabel('Observed')
551
+ axes[0].set_title(f'Time Series Decomposition: {target_col} ({method}, period={period})')
552
+ axes[0].grid(True, alpha=0.3)
553
+
554
+ # Trend
555
+ if 'trend' in decomposition and decomposition['trend'] is not None:
556
+ axes[1].plot(decomposition['trend'])
557
+ axes[1].set_ylabel('Trend')
558
+ axes[1].grid(True, alpha=0.3)
559
+
560
+ # Seasonality
561
+ if 'seasonal' in decomposition and decomposition['seasonal'] is not None:
562
+ axes[2].plot(decomposition['seasonal'])
563
+ axes[2].set_ylabel('Seasonality')
564
+ axes[2].grid(True, alpha=0.3)
565
+
566
+ # Residuals
567
+ if 'residual' in decomposition and decomposition['residual'] is not None:
568
+ axes[3].plot(decomposition['residual'])
569
+ axes[3].set_ylabel('Residuals')
570
+ axes[3].set_xlabel('Date')
571
+ axes[3].grid(True, alpha=0.3)
572
+
573
+ plt.tight_layout()
574
+ plt.savefig(
575
+ f'{self.config.results_dir}/plots/decomposition_{target_col}.png',
576
+ dpi=300,
577
+ bbox_inches='tight'
578
+ )
579
+ plt.show()
580
+
581
+ # Additional plots
582
+ self._plot_decomposition_components(data, target_col, decomposition)
583
+
584
+ def _plot_decomposition_components(
585
+ self,
586
+ data: pd.DataFrame,
587
+ target_col: str,
588
+ decomposition: Dict
589
+ ) -> None:
590
+ """Visualise decomposition components"""
591
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
592
+
593
+ # 1. Sum of components vs original series
594
+ if all(k in decomposition for k in ['trend', 'seasonal', 'residual']):
595
+ reconstructed = decomposition['trend'] + decomposition['seasonal'] + decomposition['residual']
596
+ axes[0, 0].plot(decomposition['observed'], alpha=0.7, label='Original')
597
+ axes[0, 0].plot(reconstructed, alpha=0.7, label='Reconstructed')
598
+ axes[0, 0].set_title('Original vs Reconstructed Series')
599
+ axes[0, 0].set_xlabel('Date')
600
+ axes[0, 0].set_ylabel(target_col)
601
+ axes[0, 0].legend()
602
+ axes[0, 0].grid(True, alpha=0.3)
603
+
604
+ # 2. Residuals distribution
605
+ if 'residual' in decomposition and decomposition['residual'] is not None:
606
+ residuals = decomposition['residual'].dropna()
607
+ axes[0, 1].hist(residuals, bins=30, edgecolor='black', alpha=0.7, density=True)
608
+
609
+ # Normal distribution for comparison
610
+ xmin, xmax = axes[0, 1].get_xlim()
611
+ x = np.linspace(xmin, xmax, 100)
612
+ p = stats.norm.pdf(x, residuals.mean(), residuals.std())
613
+ axes[0, 1].plot(x, p, 'k', linewidth=2, label='Normal distribution')
614
+
615
+ axes[0, 1].set_title('Residuals Distribution')
616
+ axes[0, 1].set_xlabel('Residuals')
617
+ axes[0, 1].set_ylabel('Density')
618
+ axes[0, 1].legend()
619
+ axes[0, 1].grid(True, alpha=0.3)
620
+
621
+ # 3. ACF of residuals
622
+ if 'residual' in decomposition and decomposition['residual'] is not None:
623
+ plot_acf(decomposition['residual'].dropna(), lags=50, ax=axes[1, 0], alpha=0.05)
624
+ axes[1, 0].set_title('Residuals ACF')
625
+ axes[1, 0].set_xlabel('Lag')
626
+ axes[1, 0].set_ylabel('Autocorrelation')
627
+ axes[1, 0].grid(True, alpha=0.3)
628
+
629
+ # 4. Seasonal pattern
630
+ if 'seasonal' in decomposition and decomposition['seasonal'] is not None:
631
+ seasonal = decomposition['seasonal']
632
+ if isinstance(seasonal.index, pd.DatetimeIndex):
633
+ # Group by months
634
+ try:
635
+ monthly_seasonal = seasonal.groupby(seasonal.index.month).mean()
636
+ axes[1, 1].bar(monthly_seasonal.index, monthly_seasonal.values)
637
+ axes[1, 1].set_title('Average Seasonal Pattern by Month')
638
+ axes[1, 1].set_xlabel('Month')
639
+ axes[1, 1].set_ylabel('Seasonality')
640
+ axes[1, 1].set_xticks(range(1, 13))
641
+ axes[1, 1].grid(True, alpha=0.3)
642
+ except Exception:
643
+ axes[1, 1].plot(seasonal.index, seasonal.values)
644
+ axes[1, 1].set_title('Seasonal Component')
645
+ axes[1, 1].grid(True, alpha=0.3)
646
+
647
+ plt.tight_layout()
648
+ plt.savefig(
649
+ f'{self.config.results_dir}/plots/decomposition_components_{target_col}.png',
650
+ dpi=300,
651
+ bbox_inches='tight'
652
+ )
653
+ plt.show()
654
+
655
+ def _plot_residuals_analysis(self, residuals, target_col: str) -> None:
656
+ """Visualise residual analysis"""
657
+ if residuals is None:
658
+ return
659
+
660
+ residuals_clean = residuals.dropna()
661
+
662
+ if len(residuals_clean) == 0:
663
+ return
664
+
665
+ fig, axes = plt.subplots(1, 2, figsize=(12, 4))
666
+
667
+ # Q-Q plot
668
+ stats.probplot(residuals_clean, dist="norm", plot=axes[0])
669
+ axes[0].set_title('Residuals Q-Q plot')
670
+ axes[0].grid(True, alpha=0.3)
671
+
672
+ # Residuals over time
673
+ axes[1].plot(residuals_clean.index, residuals_clean.values, linewidth=0.5)
674
+ axes[1].axhline(y=0, color='r', linestyle='-', alpha=0.3)
675
+ axes[1].set_title('Residuals Over Time')
676
+ axes[1].set_xlabel('Date')
677
+ axes[1].set_ylabel('Residuals')
678
+ axes[1].grid(True, alpha=0.3)
679
+
680
+ plt.tight_layout()
681
+ plt.savefig(
682
+ f'{self.config.results_dir}/plots/residuals_analysis_{target_col}.png',
683
+ dpi=300,
684
+ bbox_inches='tight'
685
+ )
686
+ plt.show()
687
+
688
+ def get_report(self) -> Dict:
689
+ """Get decomposition report"""
690
+ return self.decomposition_results
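A minimal sketch of running the decomposer against the synthetic loader above. It assumes `save_plots` is disabled or `{results_dir}/plots` already exists, since the plotting helpers save there without creating the directory.

```python
from config.config import Config
from data_loader.data_loader import DataLoader
from decomposition.decomposer import TimeSeriesDecomposer

config = Config()  # assumes target_column='raskhodvoda' and save_plots=False for this sketch
loader = DataLoader(config)
df = loader.create_synthetic_data(n_days=365 * 5)

decomposer = TimeSeriesDecomposer(config)
result = decomposer.decompose(df, target_col='raskhodvoda', method='stl', period=365)

stats = result.get('decomposition_stats', {})
print(f"trend strength:    {stats.get('trend_strength', 0.0):.2f}")
print(f"seasonal strength: {stats.get('seasonal_strength', 0.0):.2f}")
```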
feature_selection/__init__.py ADDED
File without changes
feature_selection/feature_selector.py ADDED
@@ -0,0 +1,478 @@
1
+ # ============================================
2
+ # CLASS 11: FEATURE SELECTION
3
+ # ============================================
4
+ from typing import Dict, List, Optional, Tuple
5
+ import logging
+ logger = logging.getLogger(__name__)
6
+ from config.config import Config
7
+
8
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from sklearn.ensemble import RandomForestRegressor
+ from sklearn.decomposition import PCA
+ from sklearn.preprocessing import StandardScaler
19
+
20
+ from sklearn.inspection import permutation_importance, partial_dependence
21
+ from sklearn.feature_selection import (
22
+ SelectKBest, SelectPercentile, RFE, RFECV, VarianceThreshold,
23
+ f_regression, mutual_info_regression
24
+ )
25
+
26
+ class FeatureSelector:
27
+ """Class for selecting the most important features"""
28
+
29
+ def __init__(self, config: Config):
30
+ """
31
+ Initialise feature selector
32
+
33
+ Parameters:
34
+ -----------
35
+ config : Config
36
+ Experiment configuration
37
+ """
38
+ self.config = config
39
+ self.selected_features = []
40
+ self.feature_importances = {}
41
+ self.selection_methods = {}
42
+ self.selector_objects = {}
43
+
44
+ def select(
45
+ self,
46
+ data: pd.DataFrame,
47
+ target_col: Optional[str] = None,
48
+ method: str = None,
49
+ n_features: int = None,
50
+ **kwargs
51
+ ) -> pd.DataFrame:
52
+ """
53
+ Select the most important features
54
+
55
+ Parameters:
56
+ -----------
57
+ data : pd.DataFrame
58
+ Input data
59
+ target_col : str, optional
60
+ Target variable. If None, uses configuration value.
61
+ method : str, optional
62
+ Selection method. If None, uses configuration value.
63
+ n_features : int, optional
64
+ Number of features to select. If None, uses configuration value.
65
+ **kwargs : dict
66
+ Additional parameters for method
67
+
68
+ Returns:
69
+ --------
70
+ pd.DataFrame
71
+ Data with selected features
72
+ """
73
+ logger.info("\n" + "="*80)
74
+ logger.info("FEATURE SELECTION")
75
+ logger.info("="*80)
76
+
77
+ target_col = target_col or self.config.target_column
78
+ method = method or self.config.feature_selection_method
+         n_features = n_features or self.config.max_features
+ 
+         if target_col not in data.columns:
+             logger.error(f"Target variable '{target_col}' not found")
+             return data
+ 
+         # Prepare data
+         X = data.drop(columns=[target_col]).select_dtypes(include=[np.number])
+         y = data[target_col]
+ 
+         # Remove missing values
+         mask = X.notna().all(axis=1) & y.notna()
+         X_clean = X[mask]
+         y_clean = y[mask]
+ 
+         if len(X_clean) < 10 or len(X_clean.columns) < 2:
+             logger.warning("Insufficient data for feature selection")
+             return data
+ 
+         logger.info(f"Selection method: {method}")
+         logger.info(f"Target number of features: {n_features}")
+         logger.info(f"Initial number of features: {len(X.columns)}")
+         logger.info(f"Data for selection: {len(X_clean)} records")
+ 
+         # Apply the selection method
+         selected_features_list = []
+         feature_importance_dict = {}
+ 
+         if method == 'correlation':
+             selected_features_list, feature_importance_dict = self._correlation_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         elif method == 'mutual_info':
+             selected_features_list, feature_importance_dict = self._mutual_info_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         elif method == 'rf':
+             selected_features_list, feature_importance_dict = self._random_forest_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         elif method == 'pca':
+             selected_features_list, feature_importance_dict = self._pca_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         elif method == 'rfe':
+             selected_features_list, feature_importance_dict = self._rfe_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         elif method == 'lasso':
+             selected_features_list, feature_importance_dict = self._lasso_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         elif method == 'hybrid':
+             selected_features_list, feature_importance_dict = self._hybrid_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+         else:
+             logger.warning(f"Method {method} not supported, using correlation")
+             selected_features_list, feature_importance_dict = self._correlation_selection(
+                 X_clean, y_clean, n_features, **kwargs
+             )
+ 
+         # Save the selected features
+         self.selected_features = selected_features_list
+         self.feature_importances = feature_importance_dict
+         self.selection_methods[method] = {
+             'selected_features': selected_features_list,
+             'n_features': len(selected_features_list),
+             'feature_importances': feature_importance_dict
+         }
+ 
+         # Form the final dataset
+         features_to_keep = selected_features_list + [target_col]
+         features_to_keep = [f for f in features_to_keep if f in data.columns]
+ 
+         data_selected = data[features_to_keep].copy()
+ 
+         logger.info(f"✓ Selected {len(selected_features_list)} features")
+         logger.info(f"  Total features kept: {len(data_selected.columns)}")
+ 
+         # Visualisation
+         if self.config.save_plots and selected_features_list:
+             self._plot_feature_selection(
+                 X_clean, y_clean, selected_features_list,
+                 feature_importance_dict, method
+             )
+ 
+         return data_selected
+ 
+     def _correlation_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Feature selection based on correlation"""
+         # Absolute correlation of each feature with the target, strongest first
+         correlations = X.corrwith(y).abs().sort_values(ascending=False)
+ 
+         # Select the top n_features
+         selected_features = correlations.head(n_features).index.tolist()
+         feature_importance = correlations.to_dict()
+ 
+         return selected_features, feature_importance
+ 
+     def _mutual_info_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Feature selection based on mutual information"""
+         try:
+             mi_scores = mutual_info_regression(X, y, random_state=kwargs.get('random_state', 42))
+             mi_series = pd.Series(mi_scores, index=X.columns)
+             mi_series = mi_series.sort_values(ascending=False)
+ 
+             selected_features = mi_series.head(n_features).index.tolist()
+             feature_importance = mi_series.to_dict()
+ 
+             return selected_features, feature_importance
+ 
+         except Exception as e:
+             logger.warning(f"Mutual information selection failed: {e}, using correlation")
+             return self._correlation_selection(X, y, n_features, **kwargs)
+ 
+     def _random_forest_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Feature selection based on Random Forest importances"""
+         try:
+             rf = RandomForestRegressor(
+                 n_estimators=kwargs.get('n_estimators', 100),
+                 max_depth=kwargs.get('max_depth', None),
+                 random_state=kwargs.get('random_state', 42),
+                 n_jobs=self.config.n_jobs if self.config.use_multiprocessing else None
+             )
+ 
+             rf.fit(X, y)
+             importances = pd.Series(rf.feature_importances_, index=X.columns)
+             importances = importances.sort_values(ascending=False)
+ 
+             selected_features = importances.head(n_features).index.tolist()
+             feature_importance = importances.to_dict()
+ 
+             self.selector_objects['random_forest'] = rf
+ 
+             return selected_features, feature_importance
+ 
+         except Exception as e:
+             logger.warning(f"Random Forest selection failed: {e}, using correlation")
+             return self._correlation_selection(X, y, n_features, **kwargs)
+ 
+     def _pca_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Feature selection based on PCA loadings"""
+         try:
+             # Standardise the data first
+             from sklearn.preprocessing import StandardScaler
+ 
+             scaler = StandardScaler()
+             X_scaled = scaler.fit_transform(X)
+ 
+             # Apply PCA
+             pca = PCA(n_components=min(n_features, len(X.columns)))
+             pca.fit(X_scaled)
+ 
+             # Score each feature by the summed absolute component loadings
+             importance = np.abs(pca.components_).sum(axis=0)
+             importance_series = pd.Series(importance, index=X.columns)
+             importance_series = importance_series.sort_values(ascending=False)
+ 
+             selected_features = importance_series.head(n_features).index.tolist()
+             feature_importance = importance_series.to_dict()
+ 
+             self.selector_objects['pca'] = pca
+             self.selector_objects['scaler'] = scaler
+ 
+             return selected_features, feature_importance
+ 
+         except Exception as e:
+             logger.warning(f"PCA selection failed: {e}, using correlation")
+             return self._correlation_selection(X, y, n_features, **kwargs)
+ 
+     def _rfe_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Recursive Feature Elimination"""
+         try:
+             from sklearn.feature_selection import RFE
+             from sklearn.linear_model import LinearRegression
+ 
+             estimator = LinearRegression()
+             rfe = RFE(
+                 estimator=estimator,
+                 n_features_to_select=n_features,
+                 step=kwargs.get('step', 1)
+             )
+ 
+             rfe.fit(X, y)
+             selected_features = X.columns[rfe.support_].tolist()
+ 
+             # Convert RFE rankings (1 = best) into importances
+             ranking = pd.Series(rfe.ranking_, index=X.columns)
+             feature_importance = (1 / ranking).to_dict()
+ 
+             self.selector_objects['rfe'] = rfe
+ 
+             return selected_features, feature_importance
+ 
+         except Exception as e:
+             logger.warning(f"RFE selection failed: {e}, using correlation")
+             return self._correlation_selection(X, y, n_features, **kwargs)
+ 
+     def _lasso_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Feature selection using Lasso"""
+         try:
+             from sklearn.linear_model import LassoCV
+ 
+             lasso = LassoCV(
+                 cv=kwargs.get('cv', 5),
+                 random_state=kwargs.get('random_state', 42),
+                 max_iter=kwargs.get('max_iter', 1000)
+             )
+ 
+             lasso.fit(X, y)
+ 
+             # Keep only features with non-zero coefficients
+             coefficients = pd.Series(lasso.coef_, index=X.columns)
+             non_zero_features = coefficients[coefficients != 0].abs().sort_values(ascending=False)
+ 
+             # Select the top n_features
+             selected_features = non_zero_features.head(n_features).index.tolist()
+             feature_importance = non_zero_features.to_dict()
+ 
+             self.selector_objects['lasso'] = lasso
+ 
+             return selected_features, feature_importance
+ 
+         except Exception as e:
+             logger.warning(f"Lasso selection failed: {e}, using correlation")
+             return self._correlation_selection(X, y, n_features, **kwargs)
+ 
+     def _hybrid_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         n_features: int,
+         **kwargs
+     ) -> Tuple[List[str], Dict]:
+         """Hybrid feature selection combining several base methods"""
+         methods = kwargs.get('methods', ['correlation', 'mutual_info', 'rf'])
+         weights = kwargs.get('weights', [0.3, 0.3, 0.4])
+ 
+         all_importances = {}
+ 
+         for method, weight in zip(methods, weights):
+             try:
+                 if method == 'correlation':
+                     _, importance = self._correlation_selection(X, y, n_features, **kwargs)
+                 elif method == 'mutual_info':
+                     _, importance = self._mutual_info_selection(X, y, n_features, **kwargs)
+                 elif method == 'rf':
+                     _, importance = self._random_forest_selection(X, y, n_features, **kwargs)
+                 else:
+                     continue
+ 
+                 # Min-max normalise the importances before weighting
+                 importance_series = pd.Series(importance)
+                 if importance_series.max() > importance_series.min():
+                     importance_normalized = (importance_series - importance_series.min()) / \
+                                             (importance_series.max() - importance_series.min())
+                 else:
+                     importance_normalized = pd.Series(1, index=importance_series.index)
+ 
+                 # Accumulate the weighted importances
+                 for feature in importance_normalized.index:
+                     if feature not in all_importances:
+                         all_importances[feature] = 0
+                     all_importances[feature] += importance_normalized[feature] * weight
+ 
+             except Exception as e:
+                 logger.debug(f"Method {method} failed in hybrid selection: {e}")
+ 
+         # Rank by combined importance
+         combined_importance = pd.Series(all_importances).sort_values(ascending=False)
+         selected_features = combined_importance.head(n_features).index.tolist()
+ 
+         return selected_features, combined_importance.to_dict()
+ 
+     def _plot_feature_selection(
+         self,
+         X: pd.DataFrame,
+         y: pd.Series,
+         selected_features: List[str],
+         feature_importance: Dict,
+         method: str
+     ) -> None:
+         """Visualise feature selection results"""
+         # Prepare data for plotting
+         importance_series = pd.Series(feature_importance).sort_values(ascending=False)
+ 
+         # Limit the number of features displayed
+         display_features = importance_series.head(20)
+ 
+         fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+ 
+         # 1. Feature importance
+         y_pos = np.arange(len(display_features))
+         axes[0, 0].barh(y_pos, display_features.values)
+         axes[0, 0].set_yticks(y_pos)
+         axes[0, 0].set_yticklabels(display_features.index, fontsize=9)
+         axes[0, 0].invert_yaxis()
+         axes[0, 0].set_xlabel('Importance')
+         axes[0, 0].set_title(f'Top-{len(display_features)} features by importance ({method})')
+         axes[0, 0].grid(True, alpha=0.3, axis='x')
+ 
+         # 2. Cumulative importance
+         cumulative_importance = importance_series.cumsum() / importance_series.sum()
+         axes[0, 1].plot(range(1, len(cumulative_importance) + 1), cumulative_importance.values)
+         axes[0, 1].axhline(y=0.8, color='r', linestyle='--', alpha=0.7, label='80% importance')
+         axes[0, 1].axhline(y=0.9, color='orange', linestyle='--', alpha=0.7, label='90% importance')
+         axes[0, 1].set_xlabel('Number of features')
+         axes[0, 1].set_ylabel('Cumulative importance')
+         axes[0, 1].set_title('Cumulative feature importance')
+         axes[0, 1].legend()
+         axes[0, 1].grid(True, alpha=0.3)
+ 
+         # 3. Correlation matrix of the selected features
+         if len(selected_features) > 1:
+             selected_X = X[selected_features]
+             corr_matrix = selected_X.corr()
+ 
+             mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
+             sns.heatmap(
+                 corr_matrix,
+                 annot=True,
+                 fmt='.2f',
+                 cmap='coolwarm',
+                 center=0,
+                 square=True,
+                 mask=mask,
+                 cbar_kws={'shrink': 0.8},
+                 ax=axes[1, 0]
+             )
+             axes[1, 0].set_title(f'Correlation of selected features ({len(selected_features)})')
+ 
+         # 4. Importance distribution
+         axes[1, 1].hist(importance_series.values, bins=30, edgecolor='black', alpha=0.7)
+         axes[1, 1].set_xlabel('Feature importance')
+         axes[1, 1].set_ylabel('Frequency')
+         axes[1, 1].set_title('Feature importance distribution')
+         axes[1, 1].grid(True, alpha=0.3)
+ 
+         plt.suptitle(f'Feature selection results using {method} method', fontsize=14)
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/feature_selection_{method}.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+ 
+     def get_report(self) -> Dict:
+         """Get the feature selection report"""
+         return {
+             'selected_features': self.selected_features,
+             'feature_importances': self.feature_importances,
+             'selection_methods': self.selection_methods
+         }
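
The `hybrid` branch above is the only non-trivial combination step: each base method's importances are min-max normalised onto [0, 1] and then blended with user-supplied weights. A self-contained sketch of just that blending logic (the function name and the sample values are illustrative, not part of the repository):

```python
import pandas as pd

def combine_importances(per_method: dict, weights: dict) -> pd.Series:
    """Min-max normalise each method's importances, then blend with weights."""
    combined = pd.Series(dtype=float)
    for name, imp in per_method.items():
        span = imp.max() - imp.min()
        norm = (imp - imp.min()) / span if span > 0 else pd.Series(1.0, index=imp.index)
        combined = combined.add(norm * weights[name], fill_value=0.0)
    return combined.sort_values(ascending=False)

corr = pd.Series({"x0": 0.90, "x1": 0.20, "x2": 0.50})  # e.g. |corr| with target
mi = pd.Series({"x0": 0.70, "x1": 0.10, "x2": 0.60})    # e.g. mutual information
print(combine_importances({"corr": corr, "mi": mi}, {"corr": 0.5, "mi": 0.5}))
# x0 ranks first (1.0), x2 second (~0.63), x1 last (0.0)
```

Normalising before weighting matters because the raw scales differ: absolute correlations live in [0, 1], while Random Forest importances sum to 1 across all features.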
features/__init__.py ADDED
File without changes
features/feature_engineer.py ADDED
@@ -0,0 +1,638 @@
+ # ============================================
+ # CLASS 5: FEATURE ENGINEER
+ # ============================================
+ import logging
+ 
+ from typing import Dict, List, Optional
+ 
+ import numpy as np
+ import pandas as pd
+ 
+ from config.config import Config
+ 
+ # Module-level logger (the original `from venv import logger` was a stray
+ # auto-import and does not provide a usable logger)
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class FeatureEngineer:
+     """Class for creating new features for time series"""
+ 
+     def __init__(self, config: Config):
+         """
+         Initialise the feature engineer
+ 
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.created_features = []
+         self.feature_info = {}
+         self.feature_importances = {}
+         self.transforms_applied = {}
+ 
+     def create_all_features(
+         self,
+         data: pd.DataFrame,
+         target_col: Optional[str] = None
+     ) -> pd.DataFrame:
+         """
+         Create all types of features
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str, optional
+             Target variable. If None, the configured value is used.
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with all features
+         """
+         logger.info("\n" + "=" * 80)
+         logger.info("CREATING FEATURES FOR TIME SERIES")
+         logger.info("=" * 80)
+ 
+         target_col = target_col or self.config.target_column
+         initial_features = len(data.columns)
+         initial_rows = len(data)
+ 
+         # Check the index type before deciding which feature groups apply
+         index_is_datetime = isinstance(data.index, pd.DatetimeIndex)
+ 
+         logger.info(f"Initial number of features: {initial_features}")
+         logger.info(f"Initial number of rows: {initial_rows}")
+         logger.info(f"Index is DatetimeIndex: {index_is_datetime}")
+ 
+         # If the index is not a DatetimeIndex but a 'date' column exists, use it
+         if not index_is_datetime and 'date' in data.columns:
+             logger.info("Attempting to set DatetimeIndex from 'date' column")
+             try:
+                 data = data.set_index('date')
+                 if isinstance(data.index, pd.DatetimeIndex):
+                     index_is_datetime = True
+                     logger.info("✓ DatetimeIndex set from 'date' column")
+                 else:
+                     logger.warning("Failed to set DatetimeIndex")
+             except Exception as e:
+                 logger.warning(f"Error setting DatetimeIndex: {e}")
+ 
+         # Work on a copy so the caller's frame stays untouched
+         data_processed = data.copy()
+ 
+         # 1. Basic temporal features (only if a date index exists)
+         if index_is_datetime:
+             logger.info("\n1. BASIC TEMPORAL FEATURES")
+             data_processed = self.create_temporal_features(data_processed)
+         else:
+             logger.info("\n1. BASIC TEMPORAL FEATURES: skipped (no DatetimeIndex)")
+ 
+         # 2. Statistical features
+         logger.info("\n2. STATISTICAL FEATURES")
+         data_processed = self.create_statistical_features(data_processed, target_col)
+ 
+         # 3. Rolling features
+         logger.info("\n3. ROLLING FEATURES")
+         data_processed = self.create_rolling_features(data_processed, target_col)
+ 
+         # 4. Lag features (limited quantity)
+         logger.info("\n4. LAG FEATURES")
+         data_processed = self.create_lag_features(data_processed, target_col)
+ 
+         # 5. Interaction features
+         logger.info("\n5. INTERACTION FEATURES")
+         data_processed = self.create_interaction_features(data_processed, target_col)
+ 
+         # 6. Spectral features (only with sufficient data)
+         logger.info("\n6. SPECTRAL FEATURES")
+         if len(data_processed) > 100:
+             data_processed = self.create_spectral_features(data_processed, target_col)
+         else:
+             logger.info("   Skipped: insufficient data")
+ 
+         # 7. Decomposition features (needs enough data and a date index)
+         logger.info("\n7. DECOMPOSITION FEATURES")
+         if len(data_processed) > 365 and index_is_datetime:
+             data_processed = self.create_decomposition_features(data_processed, target_col)
+         else:
+             logger.info("   Skipped: insufficient data or no DatetimeIndex")
+ 
+         # Remove rows with NaN introduced by lags and differences
+         rows_before_nan = len(data_processed)
+         data_processed = data_processed.dropna()
+         removed_rows = rows_before_nan - len(data_processed)
+ 
+         # Remove constant features
+         constant_features = [
+             col for col in data_processed.columns
+             if data_processed[col].nunique() <= 1
+         ]
+ 
+         if constant_features:
+             logger.info(f"\nRemoving constant features: {len(constant_features)} found")
+             for feat in constant_features[:10]:
+                 logger.info(f"   - {feat}")
+             if len(constant_features) > 10:
+                 logger.info(f"   ... and {len(constant_features) - 10} more features")
+ 
+             data_processed = data_processed.drop(columns=constant_features)
+             # Keep the created-features list in sync
+             self.created_features = [f for f in self.created_features if f not in constant_features]
+ 
+         # Save summary information
+         self.feature_info = {
+             'initial_features': initial_features,
+             'final_features': len(data_processed.columns),
+             'features_created': len(self.created_features),
+             'initial_rows': initial_rows,
+             'final_rows': len(data_processed),
+             'removed_rows': removed_rows,
+             'constant_features_removed': len(constant_features),
+             'created_features_list': self.created_features,
+             'feature_categories': self.get_feature_categories()
+         }
+ 
+         logger.info("\nFeature creation summary:")
+         logger.info(f"   Initial number of features: {initial_features}")
+         logger.info(f"   Final number of features: {len(data_processed.columns)}")
+         logger.info(f"   New features created: {len(self.created_features)}")
+         logger.info(f"   Initial number of rows: {initial_rows}")
+         logger.info(f"   Final number of rows: {len(data_processed)}")
+         logger.info(f"   Rows removed due to NaN: {removed_rows}")
+         logger.info(f"   Constant features removed: {len(constant_features)}")
+ 
+         return data_processed
+ 
+     def create_temporal_features(self, data: pd.DataFrame) -> pd.DataFrame:
+         """
+         Create temporal features
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with temporal features
+         """
+         data_processed = data.copy()
+ 
+         if not isinstance(data_processed.index, pd.DatetimeIndex):
+             logger.warning("Temporal features not created: index is not a DatetimeIndex")
+             return data_processed
+ 
+         try:
+             # Basic calendar features
+             data_processed['year'] = data_processed.index.year
+             data_processed['month'] = data_processed.index.month
+             data_processed['day'] = data_processed.index.day
+             data_processed['dayofyear'] = data_processed.index.dayofyear
+             data_processed['dayofweek'] = data_processed.index.dayofweek
+             data_processed['weekofyear'] = data_processed.index.isocalendar().week.astype(int)
+             data_processed['quarter'] = data_processed.index.quarter
+             data_processed['is_weekend'] = data_processed['dayofweek'].isin([5, 6]).astype(int)
+ 
+             # Cyclic encodings for seasonality
+             data_processed['month_sin'] = np.sin(2 * np.pi * data_processed['month'] / 12)
+             data_processed['month_cos'] = np.cos(2 * np.pi * data_processed['month'] / 12)
+             data_processed['dayofyear_sin'] = np.sin(2 * np.pi * data_processed['dayofyear'] / 365.25)
+             data_processed['dayofyear_cos'] = np.cos(2 * np.pi * data_processed['dayofyear'] / 365.25)
+             data_processed['dayofweek_sin'] = np.sin(2 * np.pi * data_processed['dayofweek'] / 7)
+             data_processed['dayofweek_cos'] = np.cos(2 * np.pi * data_processed['dayofweek'] / 7)
+ 
+             # Days elapsed since the start of the series (relative feature)
+             min_date = data_processed.index.min()
+             data_processed['days_from_start'] = (data_processed.index - min_date).days
+ 
+             # Register the created features
+             temporal_features = ['year', 'month', 'day', 'dayofyear', 'dayofweek',
+                                  'weekofyear', 'quarter', 'is_weekend', 'month_sin',
+                                  'month_cos', 'dayofyear_sin', 'dayofyear_cos',
+                                  'dayofweek_sin', 'dayofweek_cos', 'days_from_start']
+ 
+             self.created_features.extend([f for f in temporal_features if f not in self.created_features])
+ 
+             logger.info(f"✓ Created {len(temporal_features)} temporal features")
+ 
+         except Exception as e:
+             logger.warning(f"Error creating temporal features: {e}")
+ 
+         return data_processed
+ 
+     def create_statistical_features(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """
+         Create statistical features
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str
+             Target variable
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with statistical features
+         """
+         data_processed = data.copy()
+ 
+         if target_col not in data_processed.columns:
+             logger.warning(f"Target variable '{target_col}' not found")
+             return data_processed
+ 
+         # Yearly statistics, only if a 'year' feature is available
+         if 'year' in data_processed.columns:
+             try:
+                 yearly_stats = data_processed.groupby('year')[target_col].agg([
+                     'mean', 'std', 'min', 'max', 'median'
+                 ])
+                 yearly_stats.columns = [f'{target_col}_yearly_{col}' for col in yearly_stats.columns]
+                 # join on 'year' keeps the original index (merge would reset it)
+                 data_processed = data_processed.join(yearly_stats, on='year')
+ 
+                 # Add the created features to the list
+                 for col in yearly_stats.columns:
+                     self.created_features.append(col)
+             except Exception as e:
+                 logger.debug(f"Yearly statistics not created: {e}")
+ 
+         # Normalised feature (only if there is variation)
+         std_val = data_processed[target_col].std()
+         if std_val > 0:
+             data_processed[f'{target_col}_zscore'] = (
+                 data_processed[target_col] - data_processed[target_col].mean()
+             ) / std_val
+             self.created_features.append(f'{target_col}_zscore')
+ 
+         # Percentile-based binary features
+         try:
+             for p in [0.25, 0.5, 0.75]:
+                 quantile_val = data_processed[target_col].quantile(p)
+                 data_processed[f'{target_col}_above_p{int(p * 100)}'] = (
+                     data_processed[target_col] > quantile_val
+                 ).astype(int)
+                 self.created_features.append(f'{target_col}_above_p{int(p * 100)}')
+         except Exception as e:
+             logger.debug(f"Quantile features not created: {e}")
+ 
+         logger.info(f"✓ Statistical features created: {len([c for c in data_processed.columns if c not in data.columns])}")
+         return data_processed
+ 
+     def create_rolling_features(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """
+         Create rolling statistics
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str
+             Target variable
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with rolling features
+         """
+         data_processed = data.copy()
+ 
+         if target_col not in data_processed.columns:
+             logger.warning(f"Target variable '{target_col}' not found")
+             return data_processed
+ 
+         # Use only windows from the configuration that fit the series length
+         windows = [w for w in self.config.rolling_windows if w < len(data_processed) // 2]
+ 
+         for window in windows:
+             try:
+                 # Basic statistics over a centred window
+                 data_processed[f'{target_col}_rolling_mean_{window}'] = data_processed[target_col].rolling(
+                     window=window, min_periods=max(1, window // 4), center=True
+                 ).mean()
+ 
+                 data_processed[f'{target_col}_rolling_std_{window}'] = data_processed[target_col].rolling(
+                     window=window, min_periods=max(1, window // 4), center=True
+                 ).std()
+ 
+                 data_processed[f'{target_col}_rolling_min_{window}'] = data_processed[target_col].rolling(
+                     window=window, min_periods=max(1, window // 4), center=True
+                 ).min()
+ 
+                 data_processed[f'{target_col}_rolling_max_{window}'] = data_processed[target_col].rolling(
+                     window=window, min_periods=max(1, window // 4), center=True
+                 ).max()
+ 
+                 self.created_features.extend([
+                     f'{target_col}_rolling_mean_{window}',
+                     f'{target_col}_rolling_std_{window}',
+                     f'{target_col}_rolling_min_{window}',
+                     f'{target_col}_rolling_max_{window}'
+                 ])
+             except Exception as e:
+                 logger.debug(f"Rolling features for window {window} not created: {e}")
+                 continue
+ 
+         logger.info(f"✓ Rolling features created: {len([c for c in data_processed.columns if 'rolling' in c and c not in data.columns])}")
+         return data_processed
+ 
+     def create_lag_features(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """
+         Create lag features
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str
+             Target variable
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with lag features
+         """
+         data_processed = data.copy()
+ 
+         if target_col not in data_processed.columns:
+             logger.warning(f"Target variable '{target_col}' not found")
+             return data_processed
+ 
+         # Limited number of lags (capped at 7); larger candidate lags are skipped
+         max_lags = min(self.config.max_lags, 7)
+ 
+         for lag in [1, 2, 3, 7, 14, 30]:
+             if lag <= max_lags:
+                 data_processed[f'{target_col}_lag_{lag}'] = data_processed[target_col].shift(lag)
+                 self.created_features.append(f'{target_col}_lag_{lag}')
+ 
+         # Seasonal lag (only with sufficient data)
+         if len(data_processed) > 365:
+             try:
+                 data_processed[f'{target_col}_seasonal_lag_365'] = data_processed[target_col].shift(365)
+                 self.created_features.append(f'{target_col}_seasonal_lag_365')
+             except Exception as e:
+                 logger.debug(f"Seasonal lag not created: {e}")
+ 
+         # Differences (towards stationarity)
+         data_processed[f'{target_col}_diff_1'] = data_processed[target_col].diff(1)
+         self.created_features.append(f'{target_col}_diff_1')
+ 
+         if len(data_processed) > 7:
+             data_processed[f'{target_col}_diff_7'] = data_processed[target_col].diff(7)
+             self.created_features.append(f'{target_col}_diff_7')
+ 
+         logger.info(f"✓ Lag features created: {len([c for c in data_processed.columns if ('lag' in c or 'diff' in c) and c not in data.columns])}")
+         return data_processed
+ 
+     def create_interaction_features(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """
+         Create interaction features
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str
+             Target variable
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with interaction features
+         """
+         data_processed = data.copy()
+ 
+         if target_col not in data_processed.columns:
+             logger.warning(f"Target variable '{target_col}' not found")
+             return data_processed
+ 
+         # Interactions with temperature columns, where available
+         temp_cols = ['tavg', 'tmin', 'tmax']
+         available_temp_cols = [col for col in temp_cols if col in data_processed.columns]
+ 
+         for temp_col in available_temp_cols:
+             try:
+                 # Guard against division by zero (zeros become NaN)
+                 temp_data = data_processed[temp_col].replace(0, np.nan)
+                 if temp_data.notna().all():
+                     data_processed[f'{target_col}_{temp_col}_ratio'] = data_processed[target_col] / temp_data
+                     self.created_features.append(f'{target_col}_{temp_col}_ratio')
+ 
+                     # Product
+                     data_processed[f'{target_col}_{temp_col}_product'] = data_processed[target_col] * temp_data
+                     self.created_features.append(f'{target_col}_{temp_col}_product')
+             except Exception as e:
+                 logger.debug(f"Interaction feature with {temp_col} not created: {e}")
+ 
+         # Interaction with the water level, if present
+         if 'urovenvoda' in data_processed.columns:
+             try:
+                 uroven_data = data_processed['urovenvoda'].replace(0, np.nan)
+                 if uroven_data.notna().all():
+                     data_processed[f'{target_col}_urovenvoda_ratio'] = data_processed[target_col] / uroven_data
+                     self.created_features.append(f'{target_col}_urovenvoda_ratio')
+             except Exception as e:
+                 logger.debug(f"Interaction feature with urovenvoda not created: {e}")
+ 
+         logger.info(f"✓ Interaction features created: {len([c for c in data_processed.columns if ('ratio' in c or 'product' in c) and c not in data.columns])}")
+         return data_processed
+ 
+     def create_spectral_features(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """
+         Create spectral features
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str
+             Target variable
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with spectral features
+         """
+         data_processed = data.copy()
+ 
+         if target_col not in data_processed.columns:
+             logger.warning(f"Target variable '{target_col}' not found")
+             return data_processed
+ 
+         if len(data_processed) < 100:
+             logger.info("Insufficient data for creating spectral features")
+             return data_processed
+ 
+         try:
+             # Estimate the power spectrum of the target series
+             series = data_processed[target_col].dropna().values
+ 
+             if len(series) > 50:
+                 from scipy.signal import periodogram
+                 freqs, psd = periodogram(series, fs=1.0)
+ 
+                 # Find the dominant frequencies
+                 if len(psd) > 3:
+                     # Top-3 frequencies by spectral power
+                     top_indices = np.argsort(psd)[-3:][::-1]
+ 
+                     for i, idx in enumerate(top_indices, 1):
+                         if idx < len(freqs):
+                             freq = freqs[idx]
+                             if freq > 0:
+                                 period = 1 / freq
+                                 # Constant per column; note that such columns are later
+                                 # dropped by the constant-feature filter in create_all_features
+                                 data_processed[f'{target_col}_dominant_period_{i}'] = period
+                                 self.created_features.append(f'{target_col}_dominant_period_{i}')
+ 
+         except Exception as e:
+             logger.debug(f"Spectral features creation failed: {e}")
+ 
+         return data_processed
+ 
+     def create_decomposition_features(
+         self,
+         data: pd.DataFrame,
+         target_col: str
+     ) -> pd.DataFrame:
+         """
+         Create features based on seasonal decomposition
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         target_col : str
+             Target variable
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with decomposition features
+         """
+         data_processed = data.copy()
+ 
+         if target_col not in data_processed.columns:
+             logger.warning(f"Target variable '{target_col}' not found")
+             return data_processed
+ 
+         if len(data_processed) < 365:
+             logger.info("Insufficient data for decomposition")
+             return data_processed
+ 
+         try:
+             if isinstance(data_processed.index, pd.DatetimeIndex):
+                 # STL decomposition needs at least two yearly cycles
+                 if len(data_processed) > 730:
+                     try:
+                         from statsmodels.tsa.seasonal import STL
+ 
+                         stl = STL(
+                             data_processed[target_col].ffill(),
+                             period=365,
+                             robust=True
+                         )
+                         result = stl.fit()
+ 
+                         # Add the components
+                         data_processed[f'{target_col}_trend'] = result.trend
+                         data_processed[f'{target_col}_seasonal'] = result.seasonal
+                         data_processed[f'{target_col}_residual'] = result.resid
+ 
+                         self.created_features.extend([
+                             f'{target_col}_trend',
+                             f'{target_col}_seasonal',
+                             f'{target_col}_residual'
+                         ])
+ 
+                         logger.info("✓ STL decomposition successful")
+ 
+                     except Exception as e:
+                         logger.debug(f"STL decomposition failed: {e}")
+                         # Fall back to classical seasonal decomposition
+                         try:
+                             from statsmodels.tsa.seasonal import seasonal_decompose
+ 
+                             decomposition = seasonal_decompose(
+                                 data_processed[target_col].ffill(),
+                                 model='additive',
+                                 period=365,
+                                 extrapolate_trend='freq'
+                             )
+ 
+                             data_processed[f'{target_col}_trend'] = decomposition.trend
+                             data_processed[f'{target_col}_seasonal'] = decomposition.seasonal
+ 
+                             self.created_features.extend([
+                                 f'{target_col}_trend',
+                                 f'{target_col}_seasonal'
+                             ])
+ 
+                             logger.info("✓ Seasonal decomposition successful")
+                         except Exception as e2:
+                             logger.debug(f"Seasonal decomposition failed: {e2}")
+ 
+         except Exception as e:
+             logger.debug(f"Decomposition features creation failed: {e}")
+ 
+         return data_processed
+ 
+     def get_feature_categories(self) -> Dict[str, List[str]]:
+         """Group the created features by category"""
+         categories = {
+             'temporal': [],
+             'statistical': [],
+             'rolling': [],
+             'lag': [],
+             'interaction': [],
+             'spectral': [],
+             'decomposition': [],
+             'binary': []
+         }
+ 
+         for feature in self.created_features:
+             # Check statistical markers before generic calendar keywords,
+             # otherwise 'yearly_' features would match 'year' and be misfiled
+             if any(keyword in feature for keyword in ['zscore', 'above_p', 'yearly_']):
+                 if 'above_p' in feature:
+                     categories['binary'].append(feature)
+                 else:
+                     categories['statistical'].append(feature)
+             elif any(keyword in feature for keyword in ['year', 'month', 'day', 'week', 'quarter', 'sin', 'cos', 'is_weekend']):
+                 categories['temporal'].append(feature)
+             elif 'rolling' in feature:
+                 categories['rolling'].append(feature)
+             elif any(keyword in feature for keyword in ['lag', 'diff']):
+                 categories['lag'].append(feature)
+             elif 'ratio' in feature or 'product' in feature:
+                 categories['interaction'].append(feature)
+             elif 'dominant' in feature:
+                 categories['spectral'].append(feature)
+             elif any(keyword in feature for keyword in ['trend', 'seasonal', 'residual']):
+                 categories['decomposition'].append(feature)
+ 
+         # Drop empty categories
+         return {k: v for k, v in categories.items() if v}
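
A minimal usage sketch for `FeatureEngineer` on synthetic data. It assumes the project's `Config` can be instantiated with defaults and exposes `target_column`, `rolling_windows`, and `max_lags`, since the class reads those attributes; treat this as an illustration rather than the repository's own entry point:

```python
import numpy as np
import pandas as pd

from config.config import Config
from features.feature_engineer import FeatureEngineer

# Synthetic daily series: trend + seasonality + noise (illustrative data only)
idx = pd.date_range("2020-01-01", periods=400, freq="D")
rng = np.random.default_rng(0)
values = np.linspace(0, 5, 400) + np.sin(np.arange(400) / 30) + rng.normal(0, 0.1, 400)
df = pd.DataFrame({"level": values}, index=idx)

engineer = FeatureEngineer(Config())  # assumes Config() provides usable defaults
enriched = engineer.create_all_features(df, target_col="level")
print(enriched.shape, "->", len(engineer.created_features), "new features")
```

Note that rows at the start of the series are dropped by the final `dropna()`, because lag and difference features have no history to draw on there.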
missing_values/__init__.py ADDED
File without changes
missing_values/missing_analyzer.py ADDED
@@ -0,0 +1,700 @@
1
+ # ============================================
2
+ # CLASS 3: MISSING VALUE ANALYSER
3
+ # ============================================
4
+ from typing import Dict, Tuple
5
+ from venv import logger
6
+
7
+ from config.config import Config
8
+ from scipy.interpolate import interp1d
9
+ from statsmodels.tsa.seasonal import seasonal_decompose, STL
10
+ try:
11
+ import pandas as pd
12
+ import numpy as np
13
+ import matplotlib.pyplot as plt
14
+ print("✅ All imports working!")
15
+ except ImportError as e:
16
+ print(f"❌ Import error: {e}")
17
+
18
+ class MissingValueAnalyser:
19
+ """Class for analysing and handling missing values"""
20
+
21
+ def __init__(self, config: Config):
22
+ """
23
+ Initialise missing value analyser
24
+
25
+ Parameters:
26
+ -----------
27
+ config : Config
28
+ Experiment configuration
29
+ """
30
+ self.config = config
31
+ self.missing_info = {}
32
+ self.handling_methods = {}
33
+ self.imputers = {}
34
+ self.missing_patterns = {}
35
+
36
+ def analyse(
37
+ self,
38
+ data: pd.DataFrame,
39
+ detailed: bool = True
40
+ ) -> Dict:
41
+ """
42
+ Analyse missing values in data
43
+
44
+ Parameters:
45
+ -----------
46
+ data : pd.DataFrame
47
+ Input data
48
+ detailed : bool
49
+ Whether to perform detailed analysis
50
+
51
+ Returns:
52
+ --------
53
+ Dict
54
+ Information about missing values
55
+ """
56
+ logger.info("\n" + "="*80)
57
+ logger.info("MISSING VALUE ANALYSIS")
58
+ logger.info("="*80)
59
+
60
+ # Calculate missing values
61
+ missing_total = data.isnull().sum()
62
+ missing_percent = (missing_total / len(data)) * 100
63
+
64
+ missing_df = pd.DataFrame({
65
+ 'missing_count': missing_total,
66
+ 'missing_percent': missing_percent,
67
+ 'dtype': data.dtypes.astype(str)
68
+ })
69
+
70
+ # Detailed analysis
71
+ if detailed:
72
+ self._detailed_missing_analysis(data, missing_df)
73
+
74
+ # Save information
75
+ self.missing_info = {
76
+ 'summary': {
77
+ col: {
78
+ 'missing_count': int(missing_df.loc[col, 'missing_count']),
79
+ 'missing_percent': float(missing_df.loc[col, 'missing_percent']),
80
+ 'dtype': missing_df.loc[col, 'dtype']
81
+ }
82
+ for col in missing_df.index
83
+ },
84
+ 'overall': {
85
+ 'total_missing': int(missing_total.sum()),
86
+ 'total_rows': int(len(data)),
87
+ 'total_cells': int(data.size),
88
+ 'overall_missing_percentage': float(missing_total.sum() / data.size * 100),
89
+ 'rows_with_any_missing': int(data.isnull().any(axis=1).sum()),
90
+ 'rows_all_missing': int(data.isnull().all(axis=1).sum()),
91
+ 'columns_with_missing': missing_df[missing_df['missing_count'] > 0].index.tolist(),
92
+ 'columns_all_missing': missing_df[missing_df['missing_count'] == len(data)].index.tolist()
93
+ }
94
+ }
95
+
96
+ # Visualisation
97
+ if self.config.save_plots:
98
+ self._plot_missing_values(data, missing_df)
99
+
100
+ # Output results
101
+ self._log_missing_summary(missing_df)
102
+
103
+ return self.missing_info
104
+
105
+ def _detailed_missing_analysis(
106
+ self,
107
+ data: pd.DataFrame,
108
+ missing_df: pd.DataFrame
109
+ ) -> None:
110
+ """Detailed missing value analysis"""
111
+ # Analyse missing patterns
112
+ missing_matrix = data.isnull()
113
+
114
+ # Row missing patterns
115
+ row_patterns = missing_matrix.apply(lambda x: ''.join(x.astype(int).astype(str)), axis=1)
116
+ row_pattern_counts = row_patterns.value_counts().head(10)
117
+
118
+ # Column missing patterns
119
+ col_patterns = missing_matrix.apply(lambda x: ''.join(x.astype(int).astype(str)), axis=0)
120
+ col_pattern_counts = col_patterns.value_counts().head(10)
121
+
122
+ # Time-based missing patterns analysis
123
+ time_patterns = {}
124
+ if isinstance(data.index, pd.DatetimeIndex):
125
+ # Missing values by time
126
+ time_missing = data.isnull().resample('M').sum()
127
+ time_patterns['monthly_missing'] = time_missing.sum(axis=1).to_dict()
128
+
129
+ # Missing values by day of week
130
+ data_with_dow = data.copy()
131
+ data_with_dow['dayofweek'] = data.index.dayofweek
132
+ dow_missing = data_with_dow.groupby('dayofweek').apply(lambda x: x.isnull().sum().sum())
133
+ time_patterns['dayofweek_missing'] = dow_missing.to_dict()
134
+
135
+ self.missing_patterns = {
136
+ 'row_patterns': row_pattern_counts.to_dict(),
137
+ 'col_patterns': col_pattern_counts.to_dict(),
138
+ 'time_patterns': time_patterns,
139
+ 'missing_correlation': missing_matrix.corr().to_dict() # Missing value correlation
140
+ }
141
+
142
+ logger.debug(f"Found {len(row_pattern_counts)} unique row missing patterns")
143
+ logger.debug(f"Found {len(col_pattern_counts)} unique column missing patterns")
144
+
145
+ def _plot_missing_values(
146
+ self,
147
+ data: pd.DataFrame,
148
+ missing_df: pd.DataFrame
149
+ ) -> None:
150
+ """Visualise missing values"""
151
+ fig, axes = plt.subplots(3, 2, figsize=(16, 12))
152
+
153
+ # 1. Missing percentage histogram
154
+ axes[0, 0].barh(
155
+ missing_df.index,
156
+ missing_df['missing_percent']
157
+ )
158
+ axes[0, 0].axvline(self.config.missing_threshold, color='red', linestyle='--')
159
+ axes[0, 0].set_title('Missing Percentage by Column')
160
+ axes[0, 0].set_xlabel('Missing Percentage (%)')
161
+ axes[0, 0].set_ylabel('Columns')
162
+ axes[0, 0].grid(True, alpha=0.3)
163
+
164
+ # 2. Missing values heatmap
165
+ missing_matrix = data.isnull()
166
+ axes[0, 1].imshow(
167
+ missing_matrix.T if len(data) > 1000 else missing_matrix.T[:1000],
168
+ aspect='auto',
169
+ cmap='binary',
170
+ interpolation='none'
171
+ )
172
+ axes[0, 1].set_title('Missing Values Matrix')
173
+ axes[0, 1].set_xlabel('Observation Index')
174
+ axes[0, 1].set_ylabel('Variables')
175
+ axes[0, 1].set_yticks(range(len(data.columns)))
176
+ axes[0, 1].set_yticklabels(data.columns, fontsize=8)
177
+
178
+ # 3. Missing values over time (if time series)
179
+ if isinstance(data.index, pd.DatetimeIndex):
180
+ time_missing = data.isnull().resample('M').sum()
181
+
182
+ axes[1, 0].plot(time_missing.sum(axis=1))
183
+ axes[1, 0].set_title('Missing Values by Month')
184
+ axes[1, 0].set_xlabel('Date')
185
+ axes[1, 0].set_ylabel('Number of Missing Values')
186
+ axes[1, 0].grid(True, alpha=0.3)
187
+
188
+ # 4. Missing values by day of week
189
+ data_with_dow = data.copy()
190
+ data_with_dow['dayofweek'] = data.index.dayofweek
191
+ dow_missing = data_with_dow.groupby('dayofweek').apply(lambda x: x.isnull().sum().sum())
192
+ dow_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
193
+
194
+ axes[1, 1].bar(range(7), dow_missing)
195
+ axes[1, 1].set_title('Missing Values by Day of Week')
196
+ axes[1, 1].set_xlabel('Day of Week')
197
+ axes[1, 1].set_ylabel('Number of Missing Values')
198
+ axes[1, 1].set_xticks(range(7))
199
+ axes[1, 1].set_xticklabels(dow_names)
200
+ axes[1, 1].grid(True, alpha=0.3)
201
+
202
+ # 5. Missing value correlation
203
+ missing_corr = data.isnull().corr()
204
+ im = axes[2, 0].imshow(
205
+ missing_corr,
206
+ cmap='coolwarm',
207
+ vmin=-1,
208
+ vmax=1,
209
+ aspect='auto'
210
+ )
211
+ axes[2, 0].set_title('Missing Value Correlation Between Variables')
212
+ axes[2, 0].set_xlabel('Variables')
213
+ axes[2, 0].set_ylabel('Variables')
214
+ plt.colorbar(im, ax=axes[2, 0])
215
+
216
+ # 6. Cumulative missing sum
217
+ cumulative_missing = data.isnull().cumsum()
218
+ for col in data.columns[:5]: # First 5 columns
219
+ if data[col].isnull().any():
220
+ axes[2, 1].plot(
221
+ cumulative_missing.index,
222
+ cumulative_missing[col],
223
+ label=col[:20]
224
+ )
225
+ axes[2, 1].set_title('Cumulative Missing Values')
226
+ axes[2, 1].set_xlabel('Time/Index')
227
+ axes[2, 1].set_ylabel('Cumulative Missing')
228
+ axes[2, 1].legend(fontsize=8)
229
+ axes[2, 1].grid(True, alpha=0.3)
230
+
231
+ plt.tight_layout()
232
+ plt.savefig(
233
+ f'{self.config.results_dir}/plots/missing_values_analysis.png',
234
+ dpi=300,
235
+ bbox_inches='tight'
236
+ )
237
+ plt.show()
238
+
239
+ def _log_missing_summary(self, missing_df: pd.DataFrame) -> None:
240
+ """Log missing value summary"""
241
+ missing_columns = missing_df[missing_df['missing_count'] > 0]
242
+
243
+ if len(missing_columns) > 0:
244
+ logger.info("MISSING VALUES FOUND:")
245
+ logger.info("-" * 50)
246
+ logger.info(f"Total missing values: {self.missing_info['overall']['total_missing']}")
247
+ logger.info(f"Overall missing percentage: {self.missing_info['overall']['overall_missing_percentage']:.2f}%")
248
+ logger.info(f"Rows with missing values: {self.missing_info['overall']['rows_with_any_missing']}")
249
+ logger.info(f"Columns with missing values: {len(self.missing_info['overall']['columns_with_missing'])}")
250
+
251
+ logger.info("\nTop-10 columns by missing values:")
252
+ top_missing = missing_df.nlargest(10, 'missing_percent')
253
+ for idx, (col, row) in enumerate(top_missing.iterrows(), 1):
254
+ logger.info(f" {idx:2d}. {col}: {int(row['missing_count'])} missing ({row['missing_percent']:.2f}%)")
255
+ else:
256
+ logger.info("✓ No missing values found")
257
+
258
+ def handle(
259
+ self,
260
+ data: pd.DataFrame,
261
+ method: str = 'interpolate',
262
+ strategy: str = 'columnwise',
263
+ **kwargs
264
+ ) -> pd.DataFrame:
265
+ """
266
+ Handle missing values
267
+
268
+ Parameters:
269
+ -----------
270
+ data : pd.DataFrame
271
+ Input data
272
+ method : str
273
+ Handling method: 'interpolate', 'ffill', 'bfill', 'mean', 'median', 'mode', 'knn', 'regression'
274
+ strategy : str
275
+ Strategy: 'columnwise', 'rowwise', 'global'
276
+ **kwargs : dict
277
+ Additional parameters for method
278
+
279
+ Returns:
280
+ --------
281
+ pd.DataFrame
282
+ Data with handled missing values
283
+ """
284
+ logger.info("\n" + "="*80)
285
+ logger.info("HANDLING MISSING VALUES")
286
+ logger.info("="*80)
287
+
288
+ data_processed = data.copy()
289
+ methods_applied = {}
290
+
291
+ # Determine columns to process
292
+ if strategy == 'columnwise':
293
+ columns_to_process = data_processed.columns
294
+ elif strategy == 'rowwise':
295
+ # Row-wise handling (for time series)
296
+ data_processed = self._handle_rowwise(data_processed, method, **kwargs)
297
+ return data_processed
298
+ else:
299
+ columns_to_process = data_processed.select_dtypes(include=[np.number]).columns
300
+
301
+ # Process each column
302
+ for col in columns_to_process:
303
+ missing_before = data_processed[col].isnull().sum()
304
+
305
+ if missing_before > 0:
306
+ # Check if missing percentage exceeds threshold
307
+ missing_percent = (missing_before / len(data_processed)) * 100
308
+
309
+ if missing_percent > self.config.missing_threshold:
310
+ logger.warning(f" {col}: {missing_before} missing ({missing_percent:.1f}%) > threshold {self.config.missing_threshold}%")
311
+
312
+ if kwargs.get('drop_high_missing', False):
313
+ data_processed = data_processed.drop(columns=[col])
314
+ method_used = f"dropped (>{self.config.missing_threshold}% missing)"
315
+ missing_after = 0
316
+ else:
317
+ # Use selected method
318
+ data_processed[col], method_used = self._apply_imputation_method(
319
+ data_processed[col], method, **kwargs
320
+ )
321
+ missing_after = data_processed[col].isnull().sum()
322
+ else:
323
+ # Use selected method
324
+ data_processed[col], method_used = self._apply_imputation_method(
325
+ data_processed[col], method, **kwargs
326
+ )
327
+ missing_after = data_processed[col].isnull().sum()
328
+
329
+ methods_applied[col] = {
330
+ 'method': method_used,
331
+ 'missing_before': int(missing_before),
332
+ 'missing_after': int(missing_after),
333
+ 'missing_percent_before': float(missing_percent)
334
+ }
335
+
336
+ if missing_before > 0:
337
+ logger.info(f" {col}: {missing_before} → {missing_after} missing ({method_used})")
338
+
339
+ self.handling_methods = methods_applied
340
+
341
+ # Check that all missing values are handled
342
+ remaining_missing = data_processed.isnull().sum().sum()
343
+ if remaining_missing == 0:
344
+ logger.info("✓ All missing values successfully handled")
345
+ else:
346
+ logger.warning(f"⚠ {remaining_missing} missing values remain")
347
+ # Additional handling of remaining missing values
348
+ data_processed = data_processed.fillna(method='ffill').fillna(method='bfill')
349
+ remaining_after = data_processed.isnull().sum().sum()
350
+ if remaining_after == 0:
351
+ logger.info("✓ Remaining missing values handled with ffill/bfill combination")
352
+
353
+ return data_processed
354
+
355
+ def _apply_imputation_method(
356
+ self,
357
+ series: pd.Series,
358
+ method: str,
359
+ **kwargs
360
+ ) -> Tuple[pd.Series, str]:
361
+ """
362
+ Apply imputation method to individual series
363
+
364
+ Parameters:
365
+ -----------
366
+ series : pd.Series
367
+ Input series
368
+ method : str
369
+ Imputation method
370
+ **kwargs : dict
371
+ Additional parameters
372
+
373
+ Returns:
374
+ --------
375
+ Tuple[pd.Series, str]
376
+ Processed series and method description
377
+ """
378
+ if method == 'interpolate':
379
+ # Interpolation for time series
380
+ if isinstance(series.index, pd.DatetimeIndex):
381
+ method_name = f"{kwargs.get('interpolation_method', 'linear')} interpolation"
382
+ series_filled = series.interpolate(
383
+ method=kwargs.get('interpolation_method', 'linear'),
384
+ limit_direction=kwargs.get('limit_direction', 'both'),
385
+ limit=kwargs.get('limit', None)
386
+ )
387
+ else:
388
+ method_name = 'linear interpolation'
389
+ series_filled = series.interpolate(method='linear')
390
+
391
+ elif method == 'time_weighted':
392
+ # Time-weighted interpolation
393
+ method_name = 'time-weighted interpolation'
394
+ series_filled = self._time_weighted_interpolation(series)
395
+
396
+ elif method == 'seasonal':
397
+ # Seasonal interpolation
398
+ method_name = 'seasonal interpolation'
399
+ series_filled = self._seasonal_interpolation(series, **kwargs)
400
+
401
+ elif method == 'ffill':
402
+ # Forward fill
403
+ method_name = 'forward fill'
404
+ series_filled = series.ffill(limit=kwargs.get('limit', None))
405
+
406
+ elif method == 'bfill':
407
+ # Backward fill
408
+ method_name = 'backward fill'
409
+ series_filled = series.bfill(limit=kwargs.get('limit', None))
410
+
411
+ elif method == 'mean':
412
+ # Mean imputation
413
+ method_name = 'mean imputation'
414
+ series_filled = series.fillna(series.mean())
415
+
416
+ elif method == 'median':
417
+ # Median imputation
418
+ method_name = 'median imputation'
419
+ series_filled = series.fillna(series.median())
420
+
421
+ elif method == 'mode':
422
+ # Mode imputation
423
+ method_name = 'mode imputation'
424
+ mode_value = series.mode()
425
+ if not mode_value.empty:
426
+ series_filled = series.fillna(mode_value.iloc[0])
427
+ else:
428
+ series_filled = series.fillna(series.median())
429
+
430
+ elif method == 'knn':
431
+ # KNN imputation
432
+ method_name = f"KNN imputation (k={kwargs.get('k', 5)})"
433
+ # Simplified version using nearest neighbour mean
434
+ series_filled = self._knn_imputation(series, k=kwargs.get('k', 5))
435
+
436
+ elif method == 'regression':
437
+ # Regression imputation
438
+ method_name = 'regression imputation'
439
+ series_filled = self._regression_imputation(series, **kwargs)
440
+
441
+ elif method == 'spline':
442
+ # Spline interpolation
443
+ method_name = 'spline interpolation'
444
+ series_filled = series.interpolate(method='spline', order=kwargs.get('order', 3))
445
+
446
+ elif method == 'stl':
447
+ # STL decomposition + interpolation
448
+ method_name = 'STL-based imputation'
449
+ series_filled = self._stl_imputation(series, **kwargs)
450
+
451
+ else:
452
+ raise ValueError(f"Unknown method: {method}")
453
+
454
+ # If missing values remain, fill with ffill/bfill
455
+ if series_filled.isnull().any():
456
+ series_filled = series_filled.ffill().bfill()
457
+ method_name += " + ffill/bfill"
458
+
459
+ return series_filled, method_name
460
+
461
+ def _time_weighted_interpolation(self, series: pd.Series) -> pd.Series:
462
+ """Time-weighted interpolation"""
463
+ if not isinstance(series.index, pd.DatetimeIndex):
464
+ return series.interpolate()
465
+
466
+ # Create timestamps
467
+ time_numeric = pd.Series(range(len(series)), index=series.index)
468
+
469
+ # Interpolate timestamps for missing values
470
+ time_interpolated = time_numeric.interpolate()
471
+
472
+ # Interpolate values based on timestamps
473
+ valid_mask = series.notna()
474
+ if valid_mask.sum() < 2:
475
+ return series.ffill().bfill()
476
+
477
+ # Use linear interpolation
478
+ valid_times = time_numeric[valid_mask]
479
+ valid_values = series[valid_mask]
480
+
481
+ # Interpolation
482
+ interp_func = interp1d(
483
+ valid_times,
484
+ valid_values,
485
+ kind='linear',
486
+ bounds_error=False,
487
+ fill_value='extrapolate'
488
+ )
489
+
490
+ series_filled = series.copy()
491
+ missing_mask = series.isna()
492
+ series_filled[missing_mask] = interp_func(time_interpolated[missing_mask])
493
+
494
+ return series_filled
495
+
496
+ def _seasonal_interpolation(
497
+ self,
498
+ series: pd.Series,
499
+ **kwargs
500
+ ) -> pd.Series:
501
+ """Seasonal interpolation"""
502
+ if not isinstance(series.index, pd.DatetimeIndex):
503
+ return series.interpolate()
504
+
505
+ period = kwargs.get('period', self.config.seasonal_period)
506
+
507
+ # Create series copy
508
+ series_filled = series.copy()
509
+
510
+ # Interpolation considering seasonality
511
+ for i in range(len(series)):
512
+ if pd.isna(series.iloc[i]):
513
+ # Find values at same seasonal position
514
+ seasonal_indices = []
515
+ for offset in range(1, 10): # Look in previous/next cycles
516
+ idx_back = i - offset * period
517
+ idx_forward = i + offset * period
518
+
519
+ if idx_back >= 0 and not pd.isna(series.iloc[idx_back]):
520
+ seasonal_indices.append(idx_back)
521
+
522
+ if idx_forward < len(series) and not pd.isna(series.iloc[idx_forward]):
523
+                     seasonal_indices.append(idx_forward)
+ 
+                 if seasonal_indices:
+                     # Take mean value from seasonal positions
+                     seasonal_values = series.iloc[seasonal_indices]
+                     series_filled.iloc[i] = seasonal_values.mean()
+ 
+         # Fill remaining missing values with regular interpolation
+         series_filled = series_filled.interpolate()
+ 
+         return series_filled
+ 
+     def _knn_imputation(
+         self,
+         series: pd.Series,
+         k: int = 5
+     ) -> pd.Series:
+         """KNN imputation for time series"""
+         # Simplified KNN for time series
+         series_filled = series.copy()
+ 
+         for i in range(len(series)):
+             if pd.isna(series.iloc[i]):
+                 # Find nearest k non-missing values
+                 distances = []
+                 values = []
+ 
+                 for j in range(max(0, i - k * 10), min(len(series), i + k * 10)):
+                     if j != i and not pd.isna(series.iloc[j]):
+                         distance = abs(i - j)
+                         distances.append(distance)
+                         values.append(series.iloc[j])
+ 
+                     if len(values) >= k:
+                         break
+ 
+                 if values:
+                     # Distance-weighted average
+                     weights = [1 / (d + 1) for d in distances]
+                     weighted_avg = np.average(values, weights=weights)
+                     series_filled.iloc[i] = weighted_avg
+ 
+         return series_filled
+ 
+     def _regression_imputation(
+         self,
+         series: pd.Series,
+         **kwargs
+     ) -> pd.Series:
+         """Regression imputation based on neighbouring values"""
+         # Simplified regression for time series
+         series_filled = series.copy()
+ 
+         if series.notna().sum() < 3:
+             return series.ffill().bfill()
+ 
+         # Use polynomial regression
+         x = np.arange(len(series))
+         y = series.values
+ 
+         # Valid values mask
+         valid_mask = ~np.isnan(y)
+ 
+         if valid_mask.sum() < 2:
+             return series.ffill().bfill()
+ 
+         # Polynomial regression degree 2
+         coeffs = np.polyfit(x[valid_mask], y[valid_mask], 2)
+         poly_func = np.poly1d(coeffs)
+ 
+         # Fill missing values
+         missing_mask = np.isnan(y)
+         series_filled.iloc[missing_mask] = poly_func(x[missing_mask])
+ 
+         return series_filled
+ 
+     def _stl_imputation(
+         self,
+         series: pd.Series,
+         **kwargs
+     ) -> pd.Series:
+         """STL decomposition-based imputation"""
+         try:
+             if not isinstance(series.index, pd.DatetimeIndex):
+                 return series.interpolate()
+ 
+             # STL decomposition
+             stl = STL(
+                 series.ffill().bfill(),  # Fill missing for STL
+                 period=kwargs.get('period', self.config.seasonal_period),
+                 robust=True
+             )
+             result = stl.fit()
+ 
+             # Reconstruct series without noise
+             reconstructed = result.trend + result.seasonal
+ 
+             # Replace missing values with reconstructed values
+             series_filled = series.copy()
+             missing_mask = series.isna()
+             series_filled[missing_mask] = reconstructed[missing_mask]
+ 
+             return series_filled
+ 
+         except Exception as e:
+             logger.warning(f"STL imputation failed: {e}, using interpolation")
+             return series.interpolate()
+ 
+     def _handle_rowwise(
+         self,
+         data: pd.DataFrame,
+         method: str,
+         **kwargs
+     ) -> pd.DataFrame:
+         """Row-wise missing value handling"""
+         data_processed = data.copy()
+ 
+         # Remove rows with high missing counts
+         if kwargs.get('drop_rows_threshold', 0) > 0:
+             threshold = kwargs['drop_rows_threshold']
+             rows_before = len(data_processed)
+             missing_per_row = data_processed.isnull().sum(axis=1) / data_processed.shape[1] * 100
+             rows_to_drop = missing_per_row[missing_per_row > threshold].index
+             data_processed = data_processed.drop(rows_to_drop)
+             rows_after = len(data_processed)
+             logger.info(f"Rows removed: {rows_before - rows_after} (missing > {threshold}%)")
+ 
+         # Row-wise imputation
+         if method == 'row_mean':
+             data_processed = data_processed.T.fillna(data_processed.mean(axis=1)).T
+         elif method == 'row_median':
+             data_processed = data_processed.T.fillna(data_processed.median(axis=1)).T
+         elif method == 'row_ffill':
+             data_processed = data_processed.ffill(axis=1).bfill(axis=1)
+ 
+         return data_processed
+ 
+     def create_validation_rules(self) -> Dict:
+         """Create validation rules based on missing value analysis"""
+         rules = {}
+ 
+         for col, info in self.missing_info['summary'].items():
+             missing_percent = info['missing_percent']
+ 
+             if missing_percent > 50:
+                 rules[col] = {
+                     'action': 'drop_column',
+                     'reason': f'Missing > 50%: {missing_percent:.1f}%'
+                 }
+             elif missing_percent > 20:
+                 rules[col] = {
+                     'action': 'advanced_imputation',
+                     'reason': f'High missing: {missing_percent:.1f}%',
+                     'recommended_method': 'knn'
+                 }
+             elif missing_percent > 5:
+                 rules[col] = {
+                     'action': 'standard_imputation',
+                     'reason': f'Moderate missing: {missing_percent:.1f}%',
+                     'recommended_method': 'interpolate'
+                 }
+             elif missing_percent > 0:
+                 rules[col] = {
+                     'action': 'simple_imputation',
+                     'reason': f'Low missing: {missing_percent:.1f}%',
+                     'recommended_method': 'ffill'
+                 }
+ 
+         return rules
+ 
+     def get_report(self) -> Dict:
+         """Get missing values report"""
+         return {
+             'missing_info': self.missing_info,
+             'handling_methods': self.handling_methods,
+             'missing_patterns': self.missing_patterns,
+             'validation_rules': self.create_validation_rules()
+         }
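A minimal usage sketch of the analyser above (illustrative, not from the commit). It drives `analyse`, `handle`, and `create_validation_rules` directly, matching the calls the pipeline makes later in this diff; the synthetic series, the `seasonal_period=7` value, and the assumption that `Config` fields default sensibly and are settable are all assumptions of the sketch.

```python
import numpy as np
import pandas as pd

from config.config import Config
from missing_values.missing_analyzer import MissingValueAnalyser

# Daily series with a few gaps punched into it
idx = pd.date_range("2020-01-01", periods=120, freq="D")
data = pd.DataFrame({"value": np.sin(np.arange(120) / 7.0)}, index=idx)
data.iloc[[10, 11, 50, 90], 0] = np.nan

config = Config(seasonal_period=7)
config.save_plots = False  # assumption: attribute is settable; avoids plot side effects

analyser = MissingValueAnalyser(config)
analyser.analyse(data, detailed=True)                 # populates missing_info
filled = analyser.handle(data, method="interpolate", strategy="columnwise")

print(analyser.create_validation_rules())             # threshold-based recommendations
print(int(filled["value"].isna().sum()))              # expect 0 after interpolation
```

The validation rules encode the thresholds used above: columns missing more than 50% are dropped, more than 20% get KNN imputation, more than 5% get interpolation, and anything else gets a forward fill.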
outliers/__init__.py ADDED
File without changes
outliers/outlier_analyzer.py ADDED
@@ -0,0 +1,857 @@
+ # ============================================
+ # CLASS 4: OUTLIER ANALYSER
+ # ============================================
+ import logging
+ from typing import Dict, List, Tuple
+ 
+ from config.config import Config
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from sklearn.neighbors import LocalOutlierFactor
+ from sklearn.covariance import EllipticEnvelope
+ from scipy import stats
+ 
+ logger = logging.getLogger(__name__)
+ 
+ class OutlierAnalyser:
+     """Class for analysing and handling outliers"""
+ 
+     def __init__(self, config: Config):
+         """
+         Initialise outlier analyser
+ 
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.outlier_info = {}
+         self.handling_methods = {}
+         self.detection_methods = {}
+         self.outlier_models = {}
+ 
+     def analyse(
+         self,
+         data: pd.DataFrame,
+         method: str = None,
+         columns: List[str] = None,
+         **kwargs
+     ) -> Dict:
+         """
+         Analyse outliers in data
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         method : str, optional
+             Detection method. If None, uses configuration value.
+         columns : List[str], optional
+             List of columns to analyse. If None, uses all numeric columns.
+         **kwargs : dict
+             Additional parameters for method
+ 
+         Returns:
+         --------
+         Dict
+             Information about outliers
+         """
+         logger.info("\n" + "="*80)
+         logger.info("OUTLIER ANALYSIS")
+         logger.info("="*80)
+ 
+         method = method or self.config.outlier_method
+         if columns is None:
+             columns = data.select_dtypes(include=[np.number]).columns
+ 
+         outliers_info = {}
+ 
+         # Apply various detection methods
+         detection_results = {}
+ 
+         # 1. Statistical methods
+         if method in ['iqr', 'zscore', 'sigma', 'all']:
+             detection_results.update(self._statistical_methods(data, columns, method, **kwargs))
+ 
+         # 2. ML methods
+         if method in ['lof', 'isolation_forest', 'elliptic_envelope', 'all']:
+             detection_results.update(self._ml_methods(data, columns, method, **kwargs))
+ 
+         # 3. Temporal methods
+         if isinstance(data.index, pd.DatetimeIndex):
+             detection_results.update(self._temporal_methods(data, columns, **kwargs))
+ 
+         # Aggregate results
+         for col in columns:
+             if col in detection_results:
+                 # Combine results from different methods
+                 combined_mask = self._combine_detection_methods(detection_results, col)
+ 
+                 outliers_count = combined_mask.sum()
+                 outliers_percent = (outliers_count / len(data)) * 100
+ 
+                 # Detailed information (col_stats avoids shadowing the scipy.stats import)
+                 col_data = data[col].dropna()
+                 col_stats = {
+                     'mean': float(col_data.mean()),
+                     'std': float(col_data.std()),
+                     'median': float(col_data.median()),
+                     'q1': float(col_data.quantile(0.25)),
+                     'q3': float(col_data.quantile(0.75)),
+                     'min': float(col_data.min()),
+                     'max': float(col_data.max()),
+                     'skewness': float(col_data.skew()),
+                     'kurtosis': float(col_data.kurtosis())
+                 }
+ 
+                 outliers_info[col] = {
+                     'method': method,
+                     'statistics': col_stats,
+                     'outliers_count': int(outliers_count),
+                     'outliers_percent': float(outliers_percent),
+                     'outlier_indices': data[combined_mask].index.tolist() if outliers_count > 0 else [],
+                     'outlier_values': data.loc[combined_mask, col].tolist() if outliers_count > 0 else [],
+                     'detection_methods': {
+                         name: {
+                             'count': int(mask.sum()),
+                             'percent': float(mask.sum() / len(data) * 100)
+                         }
+                         for name, mask in detection_results[col].items()
+                     }
+                 }
+ 
+                 logger.info(f"{col}: {outliers_count} outliers ({outliers_percent:.2f}%)")
+ 
+         self.outlier_info = outliers_info
+         self.detection_methods = detection_results
+ 
+         # Visualisation
+         if self.config.save_plots and len(columns) > 0:
+             self._plot_outlier_analysis(data, columns, outliers_info)
+ 
+         return outliers_info
+ 
+     def _statistical_methods(
+         self,
+         data: pd.DataFrame,
+         columns: List[str],
+         method: str,
+         **kwargs
+     ) -> Dict:
+         """Statistical outlier detection methods"""
+         results = {}
+ 
+         for col in columns:
+             col_results = {}
+             series = data[col].dropna()
+ 
+             if len(series) < 3:
+                 continue
+ 
+             # IQR method
+             if method in ['iqr', 'all']:
+                 q1 = series.quantile(0.25)
+                 q3 = series.quantile(0.75)
+                 iqr = q3 - q1
+                 lower_bound = q1 - self.config.outlier_alpha * iqr
+                 upper_bound = q3 + self.config.outlier_alpha * iqr
+ 
+                 iqr_mask = (data[col] < lower_bound) | (data[col] > upper_bound)
+                 col_results['iqr'] = iqr_mask
+ 
+             # Z-score method
+             if method in ['zscore', 'sigma', 'all']:
+                 z_threshold = kwargs.get('z_threshold', 3)
+                 z_scores = np.abs((data[col] - series.mean()) / series.std())
+                 z_mask = z_scores > z_threshold
+                 col_results['zscore'] = z_mask
+ 
+             # Modified Z-score method
+             if method in ['zscore', 'all']:
+                 median = series.median()
+                 mad = np.median(np.abs(series - median))
+                 if mad != 0:
+                     modified_z_scores = 0.6745 * (data[col] - median) / mad
+                     mz_mask = np.abs(modified_z_scores) > 3.5
+                     col_results['modified_zscore'] = mz_mask
+ 
+             # Tukey's fences (reuses q1/q3/iqr from the IQR branch above, same guard)
+             if method in ['iqr', 'all']:
+                 inner_lower = q1 - 1.5 * iqr
+                 inner_upper = q3 + 1.5 * iqr
+                 outer_lower = q1 - 3 * iqr
+                 outer_upper = q3 + 3 * iqr
+ 
+                 mild_mask = ((data[col] < inner_lower) | (data[col] > inner_upper)) & \
+                             ((data[col] >= outer_lower) & (data[col] <= outer_upper))
+                 extreme_mask = (data[col] < outer_lower) | (data[col] > outer_upper)
+ 
+                 col_results['tukey_mild'] = mild_mask
+                 col_results['tukey_extreme'] = extreme_mask
+ 
+             results[col] = col_results
+ 
+         return results
+ 
+     def _ml_methods(
+         self,
+         data: pd.DataFrame,
+         columns: List[str],
+         method: str,
+         **kwargs
+     ) -> Dict:
+         """ML outlier detection methods"""
+         results = {}
+ 
+         numeric_data = data[columns].dropna()
+ 
+         if len(numeric_data) < 10:
+             return results
+ 
+         try:
+             # Local Outlier Factor
+             if method in ['lof', 'all']:
+                 lof = LocalOutlierFactor(
+                     contamination=self.config.outlier_contamination,
+                     n_neighbors=kwargs.get('n_neighbors', 20)
+                 )
+                 lof_labels = lof.fit_predict(numeric_data)
+                 lof_mask = pd.Series(lof_labels == -1, index=numeric_data.index)
+ 
+                 for col in columns:
+                     if col in numeric_data.columns:
+                         if col not in results:
+                             results[col] = {}
+                         results[col]['lof'] = lof_mask
+ 
+             # Elliptic Envelope
+             if method in ['elliptic_envelope', 'all']:
+                 try:
+                     envelope = EllipticEnvelope(
+                         contamination=self.config.outlier_contamination,
+                         random_state=42
+                     )
+                     envelope_labels = envelope.fit_predict(numeric_data)
+                     envelope_mask = pd.Series(envelope_labels == -1, index=numeric_data.index)
+ 
+                     for col in columns:
+                         if col in numeric_data.columns:
+                             if col not in results:
+                                 results[col] = {}
+                             results[col]['elliptic_envelope'] = envelope_mask
+                 except Exception as e:
+                     logger.warning(f"Elliptic Envelope failed: {e}")
+ 
+         except Exception as e:
+             logger.warning(f"ML outlier detection methods failed: {e}")
+ 
+         return results
+ 
+     def _temporal_methods(
+         self,
+         data: pd.DataFrame,
+         columns: List[str],
+         **kwargs
+     ) -> Dict:
+         """Outlier detection methods for time series"""
+         results = {}
+ 
+         for col in columns:
+             col_results = {}
+             series = data[col].dropna()
+ 
+             if len(series) < 30:
+                 continue
+ 
+             # Rolling statistics method
+             window = kwargs.get('temporal_window', 30)
+             rolling_mean = series.rolling(window=window, center=True).mean()
+             rolling_std = series.rolling(window=window, center=True).std()
+ 
+             # Outliers relative to moving average
+             threshold = kwargs.get('temporal_threshold', 3)
+             temporal_mask = np.abs(series - rolling_mean) > (threshold * rolling_std)
+             col_results['temporal'] = temporal_mask
+ 
+             # Seasonal detrending + outlier detection
+             try:
+                 # Simple seasonal detrending
+                 if len(series) > 365:
+                     seasonal_period = kwargs.get('seasonal_period', 365)
+                     seasonal_mean = series.rolling(window=seasonal_period, center=True).mean()
+                     detrended = series - seasonal_mean
+ 
+                     # Outliers in detrended series
+                     q1 = detrended.quantile(0.25)
+                     q3 = detrended.quantile(0.75)
+                     iqr = q3 - q1
+                     seasonal_lower = q1 - 3 * iqr
+                     seasonal_upper = q3 + 3 * iqr
+ 
+                     seasonal_mask = (detrended < seasonal_lower) | (detrended > seasonal_upper)
+                     col_results['seasonal'] = seasonal_mask
+             except Exception as e:
+                 logger.debug(f"Seasonal outlier detection failed for {col}: {e}")
+ 
+             results[col] = col_results
+ 
+         return results
+ 
+     def _combine_detection_methods(
+         self,
+         detection_results: Dict,
+         column: str
+     ) -> pd.Series:
+         """Combine results from different detection methods"""
+         if column not in detection_results:
+             return pd.Series(False, index=pd.RangeIndex(0))
+ 
+         methods = detection_results[column]
+         combined_mask = None
+ 
+         for method_name, mask in methods.items():
+             if combined_mask is None:
+                 combined_mask = mask.copy()
+             else:
+                 # Combine via OR (outlier by any method)
+                 combined_mask = combined_mask | mask
+ 
+         # Guard against a column with no detection results at all
+         if combined_mask is None:
+             return pd.Series(False, index=pd.RangeIndex(0))
+ 
+         return combined_mask.fillna(False)
+ 
+     def _plot_outlier_analysis(
+         self,
+         data: pd.DataFrame,
+         columns: List[str],
+         outliers_info: Dict
+     ) -> None:
+         """Visualise outlier analysis"""
+         n_cols = min(len(columns), 4)
+         n_rows = (len(columns) + n_cols - 1) // n_cols
+ 
+         fig = plt.figure(figsize=(16, 4 * n_rows))
+         gs = fig.add_gridspec(n_rows, n_cols)
+ 
+         for idx, col in enumerate(columns):
+             if col not in outliers_info:
+                 continue
+ 
+             row = idx // n_cols
+             col_idx = idx % n_cols
+ 
+             ax = fig.add_subplot(gs[row, col_idx])
+ 
+             # Data
+             series = data[col].dropna()
+ 
+             # 1. Box plot
+             bp = ax.boxplot(
+                 series.values,
+                 vert=True,
+                 patch_artist=True,
+                 widths=0.6,
+                 showfliers=False
+             )
+ 
+             # Colours for box plot
+             bp['boxes'][0].set_facecolor('lightblue')
+             bp['medians'][0].set_color('red')
+             bp['whiskers'][0].set_color('black')
+             bp['whiskers'][1].set_color('black')
+             bp['caps'][0].set_color('black')
+             bp['caps'][1].set_color('black')
+ 
+             # 2. Outliers
+             if outliers_info[col]['outliers_count'] > 0:
+                 outlier_values = outliers_info[col]['outlier_values']
+ 
+                 # Jitter x positions around the single box
+                 jitter = np.random.normal(0, 0.05, len(outlier_values))
+ 
+                 ax.scatter(
+                     np.ones(len(outlier_values)) + jitter,
+                     outlier_values,
+                     color='red',
+                     alpha=0.6,
+                     s=30,
+                     edgecolors='black',
+                     label=f'Outliers ({outliers_info[col]["outliers_count"]})'
+                 )
+ 
+             # 3. Histogram on same plot
+             ax2 = ax.twinx()
+             ax2.hist(
+                 series.values,
+                 bins=30,
+                 alpha=0.3,
+                 color='green',
+                 density=True
+             )
+ 
+             # 4. Normal distribution for comparison
+             if len(series) > 10:
+                 x = np.linspace(series.min(), series.max(), 100)
+                 mean = series.mean()
+                 std = series.std()
+ 
+                 if std > 0:
+                     p = stats.norm.pdf(x, mean, std)
+                     ax2.plot(x, p, 'k--', linewidth=1, label='Normal distribution')
+ 
+             ax.set_title(f'{col}\nOutliers: {outliers_info[col]["outliers_count"]} ({outliers_info[col]["outliers_percent"]:.1f}%)')
+             ax.set_ylabel('Value')
+             ax.grid(True, alpha=0.3)
+ 
+             # Legend
+             if outliers_info[col]['outliers_count'] > 0:
+                 ax.legend(loc='upper right', fontsize=8)
+                 ax2.legend(loc='upper left', fontsize=8)
+ 
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/outliers_analysis.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+ 
+         # Additional plots for time series
+         if isinstance(data.index, pd.DatetimeIndex) and len(columns) > 0:
+             self._plot_temporal_outliers(data, columns, outliers_info)
+ 
+     def _plot_temporal_outliers(
+         self,
+         data: pd.DataFrame,
+         columns: List[str],
+         outliers_info: Dict
+     ) -> None:
+         """Visualise outliers over time"""
+         n_plots = min(len(columns), 3)
+ 
+         fig, axes = plt.subplots(n_plots, 1, figsize=(14, 4 * n_plots))
+         if n_plots == 1:
+             axes = [axes]
+ 
+         for idx, (col, ax) in enumerate(zip(columns[:n_plots], axes)):
+             if col not in outliers_info:
+                 continue
+ 
+             # Time series
+             ax.plot(data.index, data[col], alpha=0.7, linewidth=1, label='Original series')
+ 
+             # Outliers
+             if outliers_info[col]['outliers_count'] > 0:
+                 outlier_indices = outliers_info[col]['outlier_indices']
+                 outlier_values = outliers_info[col]['outlier_values']
+ 
+                 ax.scatter(
+                     outlier_indices,
+                     outlier_values,
+                     color='red',
+                     s=40,
+                     edgecolors='black',
+                     zorder=5,
+                     label='Outliers'
+                 )
+ 
+             # Moving average
+             if len(data) > 30:
+                 rolling_mean = data[col].rolling(window=30, center=True).mean()
+                 ax.plot(data.index, rolling_mean, 'orange', linewidth=2, label='Moving average (30)')
+ 
+             ax.set_title(f'Outliers over time: {col}')
+             ax.set_xlabel('Date')
+             ax.set_ylabel(col)
+             ax.legend(fontsize=8)
+             ax.grid(True, alpha=0.3)
+ 
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/temporal_outliers.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+ 
+     def handle(
+         self,
+         data: pd.DataFrame,
+         method: str = 'clip',
+         strategy: str = 'columnwise',
+         **kwargs
+     ) -> pd.DataFrame:
+         """
+         Handle outliers
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         method : str
+             Handling method: 'clip', 'remove', 'mean', 'median', 'winsorize', 'transform', 'impute'
+         strategy : str
+             Strategy: 'columnwise', 'global', 'adaptive'
+         **kwargs : dict
+             Additional parameters for method
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data with handled outliers
+         """
+         logger.info("\n" + "="*80)
+         logger.info("HANDLING OUTLIERS")
+         logger.info("="*80)
+ 
+         if not self.outlier_info:
+             logger.warning("⚠ Perform outlier analysis first")
+             return data
+ 
+         data_processed = data.copy()
+         methods_applied = {}
+ 
+         for col, info in self.outlier_info.items():
+             if col not in data_processed.columns:
+                 continue
+ 
+             outliers_count = info['outliers_count']
+ 
+             if outliers_count > 0:
+                 # Create outlier mask
+                 outlier_mask = pd.Series(False, index=data_processed.index)
+                 if info['outlier_indices']:
+                     outlier_indices = [idx for idx in info['outlier_indices'] if idx in data_processed.index]
+                     outlier_mask.loc[outlier_indices] = True
+ 
+                 # Determine boundaries
+                 col_stats = info['statistics']
+                 q1, q3 = col_stats['q1'], col_stats['q3']
+                 iqr = q3 - q1
+                 lower_bound = q1 - self.config.outlier_alpha * iqr
+                 upper_bound = q3 + self.config.outlier_alpha * iqr
+ 
+                 if method == 'clip':
+                     # Clip values to boundaries
+                     data_processed[col] = data_processed[col].clip(
+                         lower=lower_bound,
+                         upper=upper_bound
+                     )
+                     method_used = 'clipping'
+                     affected = outliers_count
+ 
+                 elif method == 'remove':
+                     # Remove rows with outliers
+                     data_processed = data_processed[~outlier_mask]
+                     method_used = 'removal'
+                     affected = outliers_count
+ 
+                 elif method == 'mean':
+                     # Replace outliers with mean value
+                     mean_val = data_processed[col].mean()
+                     data_processed.loc[outlier_mask, col] = mean_val
+                     method_used = 'mean imputation'
+                     affected = outliers_count
+ 
+                 elif method == 'median':
+                     # Replace outliers with median
+                     median_val = data_processed[col].median()
+                     data_processed.loc[outlier_mask, col] = median_val
+                     method_used = 'median imputation'
+                     affected = outliers_count
+ 
+                 elif method == 'winsorize':
+                     # Winsorisation
+                     data_processed[col] = self._winsorize_series(
+                         data_processed[col],
+                         limits=kwargs.get('limits', (0.05, 0.05))
+                     )
+                     method_used = 'winsorization'
+                     affected = outliers_count
+ 
+                 elif method == 'transform':
+                     # Transformation to reduce outlier impact
+                     transform_method = kwargs.get('transform_method', 'log')
+                     data_processed[col] = self._transform_series(
+                         data_processed[col],
+                         method=transform_method
+                     )
+                     method_used = f'{transform_method} transformation'
+                     affected = 'all'  # Transformation applied to all values
+ 
+                 elif method == 'impute':
+                     # Smart outlier imputation
+                     impute_method = kwargs.get('impute_method', 'neighbors')
+                     data_processed[col] = self._impute_outliers(
+                         data_processed[col],
+                         outlier_mask,
+                         method=impute_method,
+                         **kwargs
+                     )
+                     method_used = f'{impute_method} imputation'
+                     affected = outliers_count
+ 
+                 elif method == 'adaptive':
+                     # Adaptive handling
+                     data_processed[col] = self._adaptive_outlier_handling(
+                         data_processed[col],
+                         outlier_mask,
+                         **kwargs
+                     )
+                     method_used = 'adaptive handling'
+                     affected = outliers_count
+ 
+                 else:
+                     raise ValueError(f"Unknown method: {method}")
+ 
+                 methods_applied[col] = {
+                     'method': method_used,
+                     'outliers_before': outliers_count,
+                     'affected': affected,
+                     'bounds': {
+                         'lower': float(lower_bound),
+                         'upper': float(upper_bound)
+                     }
+                 }
+ 
+                 logger.info(f"  {col}: {outliers_count} outliers handled ({method_used})")
+ 
+         self.handling_methods = methods_applied
+ 
+         # Handling statistics
+         total_outliers = sum(info['outliers_count'] for info in self.outlier_info.values())
+         total_affected = sum(method['affected'] for method in methods_applied.values()
+                              if isinstance(method['affected'], (int, np.integer)))
+ 
+         logger.info(f"\n✓ {total_affected} out of {total_outliers} outliers handled")
+         logger.info(f"  Data size before: {len(data)} rows")
+         logger.info(f"  Data size after: {len(data_processed)} rows")
+ 
+         # Visualise results
+         if self.config.save_plots and methods_applied:
+             self._plot_outlier_handling_results(data, data_processed, methods_applied)
+ 
+         return data_processed
+ 
+     def _winsorize_series(
+         self,
+         series: pd.Series,
+         limits: Tuple[float, float] = (0.05, 0.05)
+     ) -> pd.Series:
+         """Winsorize series"""
+         from scipy.stats.mstats import winsorize
+         try:
+             winsorized = winsorize(series.values, limits=limits)
+             return pd.Series(winsorized, index=series.index)
+         except Exception:
+             return series
+ 
+     def _transform_series(
+         self,
+         series: pd.Series,
+         method: str = 'log'
+     ) -> pd.Series:
+         """Transform series to reduce outlier impact"""
+         series_transformed = series.copy()
+ 
+         if method == 'log':
+             # Logarithmic transformation
+             min_val = series.min()
+             if min_val <= 0:
+                 shift = abs(min_val) + 1
+                 series_transformed = np.log(series + shift)
+             else:
+                 series_transformed = np.log(series)
+ 
+         elif method == 'boxcox':
+             # Box-Cox transformation
+             try:
+                 from scipy.stats import boxcox
+                 transformed, _ = boxcox(series - series.min() + 1)
+                 series_transformed = pd.Series(transformed, index=series.index)
+             except Exception:
+                 logger.warning("Box-Cox transformation failed, using log")
+                 return self._transform_series(series, 'log')
+ 
+         elif method == 'sqrt':
+             # Square root
+             min_val = series.min()
+             if min_val < 0:
+                 series_transformed = np.sqrt(series - min_val)
+             else:
+                 series_transformed = np.sqrt(series)
+ 
+         elif method == 'yeojohnson':
+             # Yeo-Johnson transformation
+             try:
+                 from scipy.stats import yeojohnson
+                 transformed, _ = yeojohnson(series)
+                 series_transformed = pd.Series(transformed, index=series.index)
+             except Exception:
+                 logger.warning("Yeo-Johnson transformation failed, using log")
+                 return self._transform_series(series, 'log')
+ 
+         return series_transformed
+ 
+     def _impute_outliers(
+         self,
+         series: pd.Series,
+         outlier_mask: pd.Series,
+         method: str = 'neighbors',
+         **kwargs
+     ) -> pd.Series:
+         """Smart outlier imputation"""
+         series_imputed = series.copy()
+ 
+         if method == 'neighbors':
+             # Replace with mean of neighbouring values
+             for idx in series[outlier_mask].index:
+                 if idx in series.index:
+                     pos = series.index.get_loc(idx)
+                     neighbours = []
+ 
+                     # Find nearest non-outlier to the left
+                     for offset in range(1, 6):
+                         if pos - offset >= 0 and not outlier_mask.iloc[pos - offset]:
+                             neighbours.append(series.iloc[pos - offset])
+                             break
+ 
+                     # Find nearest non-outlier to the right
+                     for offset in range(1, 6):
+                         if pos + offset < len(series) and not outlier_mask.iloc[pos + offset]:
+                             neighbours.append(series.iloc[pos + offset])
+                             break
+ 
+                     if neighbours:
+                         series_imputed.loc[idx] = np.mean(neighbours)
+ 
+         elif method == 'interpolate':
+             # Interpolation
+             series_imputed = series.mask(outlier_mask).interpolate()
+ 
+         elif method == 'rolling':
+             # Replace with moving average
+             window = kwargs.get('window', 5)
+             rolling_mean = series.rolling(window=window, center=True, min_periods=1).mean()
+             series_imputed = series.mask(outlier_mask, rolling_mean)
+ 
+         return series_imputed
+ 
+     def _adaptive_outlier_handling(
+         self,
+         series: pd.Series,
+         outlier_mask: pd.Series,
+         **kwargs
+     ) -> pd.Series:
+         """Adaptive outlier handling"""
+         series_processed = series.copy()
+         outlier_indices = series[outlier_mask].index
+ 
+         for idx in outlier_indices:
+             if idx in series.index:
+                 value = series.loc[idx]
+                 col_stats = self.outlier_info.get(series.name, {}).get('statistics', {})
+ 
+                 # Determine outlier type
+                 q1 = col_stats.get('q1', series.quantile(0.25))
+                 q3 = col_stats.get('q3', series.quantile(0.75))
+                 iqr = q3 - q1
+ 
+                 if value < q1 - 3 * iqr:
+                     # Extreme low outlier
+                     series_processed.loc[idx] = q1 - 1.5 * iqr
+                 elif value > q3 + 3 * iqr:
+                     # Extreme high outlier
+                     series_processed.loc[idx] = q3 + 1.5 * iqr
+                 else:
+                     # Moderate outlier
+                     pos = series.index.get_loc(idx)
+                     # Use linear interpolation
+                     if pos > 0 and pos < len(series) - 1:
+                         series_processed.loc[idx] = (series.iloc[pos-1] + series.iloc[pos+1]) / 2
+ 
+         return series_processed
+ 
+     def _plot_outlier_handling_results(
+         self,
+         original_data: pd.DataFrame,
+         processed_data: pd.DataFrame,
+         methods_applied: Dict
+     ) -> None:
+         """Visualise outlier handling results"""
+         cols_to_plot = list(methods_applied.keys())[:3]
+ 
+         if not cols_to_plot:
+             return
+ 
+         fig, axes = plt.subplots(len(cols_to_plot), 2, figsize=(14, 4 * len(cols_to_plot)))
+         if len(cols_to_plot) == 1:
+             axes = axes.reshape(1, -1)
+ 
+         for idx, col in enumerate(cols_to_plot):
+             if col not in original_data.columns or col not in processed_data.columns:
+                 continue
+ 
+             # Distribution before handling
+             axes[idx, 0].hist(original_data[col].dropna(), bins=30, alpha=0.5, label='Before', density=True)
+             axes[idx, 0].hist(processed_data[col].dropna(), bins=30, alpha=0.5, label='After', density=True)
+             axes[idx, 0].set_title(f'{col}: Distribution before/after')
+             axes[idx, 0].set_xlabel('Value')
+             axes[idx, 0].set_ylabel('Density')
+             axes[idx, 0].legend()
+             axes[idx, 0].grid(True, alpha=0.3)
+ 
+             # QQ plot for normality check
+             stats.probplot(original_data[col].dropna(), dist="norm", plot=axes[idx, 1])
+             axes[idx, 1].set_title(f'{col}: Q-Q plot (before handling)')
+             axes[idx, 1].grid(True, alpha=0.3)
+ 
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/outlier_handling_results.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+ 
+     def create_validation_rules(self) -> Dict:
+         """Create validation rules based on outlier analysis"""
+         rules = {}
+ 
+         for col, info in self.outlier_info.items():
+             outliers_percent = info['outliers_percent']
+             skewness = info['statistics']['skewness']
+ 
+             rule = {
+                 'outliers_percent': outliers_percent,
+                 'skewness': skewness,
+                 'recommended_action': 'none'
+             }
+ 
+             if outliers_percent > 10:
+                 rule['recommended_action'] = 'aggressive_handling'
+                 rule['reason'] = f'High outliers: {outliers_percent:.1f}%'
+             elif outliers_percent > 5:
+                 rule['recommended_action'] = 'moderate_handling'
+                 rule['reason'] = f'Moderate outliers: {outliers_percent:.1f}%'
+             elif outliers_percent > 1:
+                 rule['recommended_action'] = 'conservative_handling'
+                 rule['reason'] = f'Low outliers: {outliers_percent:.1f}%'
+ 
+             if abs(skewness) > 1:
+                 rule['skewness_issue'] = True
+                 rule['skewness_reason'] = f'Strong skewness: {skewness:.2f}'
+                 if rule['recommended_action'] == 'none':
+                     rule['recommended_action'] = 'transformation'
+ 
+             rules[col] = rule
+ 
+         return rules
+ 
+     def get_report(self) -> Dict:
+         """Get outlier analysis report"""
+         return {
+             'outlier_info': self.outlier_info,
+             'handling_methods': self.handling_methods,
+             'detection_methods': self.detection_methods,
+             'validation_rules': self.create_validation_rules()
+         }
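The analyser above separates detection (`analyse`) from treatment (`handle`), so one detection pass can be reused with different handling methods. A small sketch (illustrative, not from the commit), assuming `Config` provides usable defaults for `outlier_method`, `outlier_alpha`, and `outlier_contamination`, and that `save_plots` is a settable attribute:

```python
import numpy as np
import pandas as pd

from config.config import Config
from outliers.outlier_analyzer import OutlierAnalyser

rng = np.random.default_rng(0)
df = pd.DataFrame({"flow": rng.normal(100, 10, 500)})
df.loc[[25, 250], "flow"] = [400.0, -150.0]   # inject two obvious outliers

config = Config(outlier_method="iqr")
config.save_plots = False                     # assumption: settable; skips the plot step

analyser = OutlierAnalyser(config)
info = analyser.analyse(df, method="iqr", columns=["flow"])
clipped = analyser.handle(df, method="clip")  # clip to Q1/Q3 -/+ outlier_alpha * IQR

print(info["flow"]["outliers_count"])          # expect at least the two injected points
print(float(clipped["flow"].min()), float(clipped["flow"].max()))
```

Because the index here is a plain RangeIndex rather than a DatetimeIndex, the temporal detection branch is skipped automatically, which is the same guard the pipeline relies on.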
pipeline/__init__.py ADDED
File without changes
pipeline/main_pipeline.py ADDED
@@ -0,0 +1,603 @@
+ # ============================================
+ # CLASS 14: MAIN PIPELINE
+ # ============================================
+ from datetime import datetime
+ import json
+ import logging
+ import os
+ import traceback
+ from typing import Any, Dict, Optional
+ 
+ from config.config import Config
+ from correlations.correlation_analyzer import CorrelationAnalyzer
+ from data_loader.data_loader import DataLoader
+ from decomposition.decomposer import TimeSeriesDecomposer
+ from feature_selection.feature_selector import FeatureSelector
+ from features.feature_engineer import FeatureEngineer
+ from missing_values.missing_analyzer import MissingValueAnalyser
+ from outliers.outlier_analyzer import OutlierAnalyser
+ from scaling.data_scaler import DataScaler
+ from splitting.data_splitter import DataSplitter
+ from stationarity.stationarity_checker import StationarityChecker
+ from validation.data_validator import DataValidator
+ from visualization.visualization_manager import VisualisationManager
+ import pandas as pd
+ import numpy as np
+ 
+ logger = logging.getLogger(__name__)
+ 
+ class EnhancedDataPreprocessingPipeline:
+     """Enhanced main data preprocessing pipeline"""
+ 
+     def __init__(self, config: Config):
+         """
+         Initialise pipeline
+ 
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.data_loader = DataLoader(config)
+         self.missing_analyser = MissingValueAnalyser(config)
+         self.outlier_analyser = OutlierAnalyser(config)
+         self.feature_engineer = FeatureEngineer(config)
+         self.stationarity_checker = StationarityChecker(config)
+         self.decomposer = TimeSeriesDecomposer(config)
+         self.correlation_analyser = CorrelationAnalyzer(config)
+         self.data_splitter = DataSplitter(config)
+         self.data_scaler = DataScaler(config)
+         self.feature_selector = FeatureSelector(config)
+         self.data_validator = DataValidator(config)
+         self.visualisation_manager = VisualisationManager(config)
+ 
+         self.results = {}
+         self.processed_data = None
+         self.train_data = None
+         self.val_data = None
+         self.test_data = None
+         self.is_fitted = False
+ 
+     def run_full_pipeline(
+         self,
+         data_path: Optional[str] = None,
+         use_synthetic: bool = False,
+         save_intermediate: bool = True,
+         create_reports: bool = True
+     ) -> pd.DataFrame:
+         """
+         Run enhanced full data preprocessing pipeline
+ 
+         Parameters:
+         -----------
+         data_path : str, optional
+             Path to data. If None, uses configuration value.
+         use_synthetic : bool
+             Use synthetic data for testing
+         save_intermediate : bool
+             Save intermediate results
+         create_reports : bool
+             Create reports
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Processed data
+         """
+         logger.info("\n" + "="*80)
+         logger.info("RUNNING ENHANCED DATA PREPROCESSING PIPELINE")
+         logger.info("="*80)
+ 
+         start_time = datetime.now()
+ 
+         try:
+             # Step 1: Data loading
+             logger.info("\n" + "="*80)
+             logger.info("STEP 1: DATA LOADING")
+             logger.info("="*80)
+ 
+             if use_synthetic:
+                 data = self.data_loader.create_synthetic_data(
+                     n_days=365*20,
+                     trend_strength=0.01,
+                     noise_std=10,
+                     include_exogenous=True
+                 )
+             else:
+                 data = self.data_loader.load_from_csv(
+                     data_path=data_path,
+                     parse_dates=['date']
+                 )
+ 
+             # Check for date index
+             if not isinstance(data.index, pd.DatetimeIndex):
+                 logger.warning("Index is not DatetimeIndex, setting...")
+                 if 'date' in data.columns:
+                     data = data.set_index('date')
+                     logger.info("Index set from 'date' column")
+ 
+             self.results['data_loading'] = {
+                 'shape': list(data.shape),
+                 'columns': list(data.columns),
+                 'date_range': {
+                     'min': data.index.min().strftime('%Y-%m-%d') if isinstance(data.index, pd.DatetimeIndex) else None,
+                     'max': data.index.max().strftime('%Y-%m-%d') if isinstance(data.index, pd.DatetimeIndex) else None
+                 },
+                 'is_datetime_index': isinstance(data.index, pd.DatetimeIndex)
+             }
+ 
+             # Save raw data information
+             self.data_loader.save_raw_data_info()
+ 
+             # Step 2: Raw data validation
+             logger.info("\n" + "="*80)
+             logger.info("STEP 2: RAW DATA VALIDATION")
+             logger.info("="*80)
+ 
+             raw_validation = self.data_validator.validate(
+                 data, stage='raw', detailed=True
+             )
+             self.results['raw_validation'] = raw_validation
+ 
+             if raw_validation['status'] == 'FAIL':
+                 logger.warning("⚠ Raw data has critical issues!")
+                 if not self.config.enable_validation:
+                     logger.warning("Validation disabled in configuration, continuing processing")
+                 else:
+                     logger.error("Pipeline interrupted due to data issues")
+                     return None
+ 
+             # Step 3: Missing values analysis and handling
+             logger.info("\n" + "="*80)
+             logger.info("STEP 3: MISSING VALUES HANDLING")
+             logger.info("="*80)
+ 
+             missing_info = self.missing_analyser.analyse(data, detailed=True)
+             self.results['missing_analysis'] = missing_info
+ 
+             # Handle missing values
+             data = self.missing_analyser.handle(
+                 data,
+                 method='interpolate',
+                 strategy='columnwise'
+             )
+             self.results['missing_handling'] = self.missing_analyser.handling_methods
+ 
+             # Step 4: Outlier analysis and handling
+             logger.info("\n" + "="*80)
+             logger.info("STEP 4: OUTLIER HANDLING")
+             logger.info("="*80)
+ 
+             outlier_info = self.outlier_analyser.analyse(
+                 data,
+                 method=self.config.outlier_method,
+                 columns=data.select_dtypes(include=[np.number]).columns.tolist()
+             )
+             self.results['outlier_analysis'] = outlier_info
+ 
+             # Handle outliers
+             data = self.outlier_analyser.handle(
+                 data,
+                 method='clip',
+                 strategy='columnwise'
+             )
+             self.results['outlier_handling'] = self.outlier_analyser.handling_methods
+ 
+             # Step 5: Feature engineering
+             logger.info("\n" + "="*80)
+             logger.info("STEP 5: FEATURE ENGINEERING")
+             logger.info("="*80)
+ 
+             data = self.feature_engineer.create_all_features(data)
+             self.results['feature_engineering'] = self.feature_engineer.feature_info
+ 
+             # Check for data after feature engineering
+             if len(data) == 0:
+                 logger.error("No data remaining after feature engineering!")
+                 return None
+ 
+             # Step 6: Stationarity analysis
+             logger.info("\n" + "="*80)
+             logger.info("STEP 6: STATIONARITY ANALYSIS")
+             logger.info("="*80)
+ 
+             stationarity_results = self.stationarity_checker.check(
+                 data,
+                 target_col=self.config.target_column,
+                 make_stationary=True,
+                 try_transformations=True
+             )
+             self.results['stationarity_analysis'] = stationarity_results
+ 
+             # Step 7: Time series decomposition
+             logger.info("\n" + "="*80)
+             logger.info("STEP 7: TIME SERIES DECOMPOSITION")
+             logger.info("="*80)
+ 
+             if isinstance(data.index, pd.DatetimeIndex) and len(data) > 365:
+                 decomposition_results = self.decomposer.decompose(
+                     data,
+                     target_col=self.config.target_column,
+                     method='stl',
+                     period=self.config.seasonal_period
+                 )
+                 self.results['decomposition'] = decomposition_results
+             else:
+                 logger.info("Skipped: insufficient data or no DatetimeIndex")
+                 self.results['decomposition'] = {'skipped': 'insufficient data or no DatetimeIndex'}
+ 
+             # Step 8: Correlation analysis
+             logger.info("\n" + "="*80)
+             logger.info("STEP 8: CORRELATION ANALYSIS")
+             logger.info("="*80)
+ 
+             corr_matrix = self.correlation_analyser.analyze(
+                 data,
+                 target_col=self.config.target_column,
+                 threshold=0.8,
+                 detailed=True
+             )
+             self.results['correlation_analysis'] = self.correlation_analyser.get_report()
+ 
+             # Remove highly correlated features
+             if not corr_matrix.empty:
+                 data = self.correlation_analyser.remove_highly_correlated(
+                     data,
+                     threshold=0.95,
+                     method='variance',
+                     keep_target=True
+                 )
+             else:
+                 logger.warning("Correlation matrix empty, skipping feature removal")
+ 
+             # Step 9: Processed data validation
+             logger.info("\n" + "="*80)
+             logger.info("STEP 9: PROCESSED DATA VALIDATION")
+             logger.info("="*80)
+ 
+             processed_validation = self.data_validator.validate(
+                 data, stage='processed', detailed=True
+             )
+             self.results['processed_validation'] = processed_validation
+ 
+             if processed_validation['status'] == 'FAIL':
+                 logger.warning("⚠ Processed data failed validation!")
+                 logger.warning("Continuing pipeline, but data quality may be low")
+             elif processed_validation['status'] == 'WARNING':
+                 logger.warning("⚠ Processed data requires attention")
+ 
+             # Step 10: Data splitting
+             logger.info("\n" + "="*80)
+             logger.info("STEP 10: DATA SPLITTING")
+             logger.info("="*80)
+ 
+             train_data, val_data, test_data = self.data_splitter.split(
+                 data,
+                 method=self.config.split_method,
+                 test_size=self.config.test_size,
+                 validation_size=self.config.validation_size
+             )
+ 
+             self.train_data = train_data
+             self.val_data = val_data
+             self.test_data = test_data
+ 
+             self.results['data_splitting'] = self.data_splitter.split_info
+ 
+             # Step 11: Data scaling
+             logger.info("\n" + "="*80)
+             logger.info("STEP 11: DATA SCALING")
+             logger.info("="*80)
+ 
+             # Scale training data
+             train_data_scaled = self.data_scaler.fit_transform(
+                 train_data,
+                 method=self.config.scaling_method,
+                 target_col=self.config.target_column,
+                 fit_on_train=True
+             )
+ 
+             # Apply same scaling to validation and test data
+             val_data_scaled = self.data_scaler.transform(val_data)
+             test_data_scaled = self.data_scaler.transform(test_data)
+ 
+             self.train_data = train_data_scaled
+             self.val_data = val_data_scaled
+             self.test_data = test_data_scaled
+ 
+             self.results['data_scaling'] = self.data_scaler.get_report()
+ 
+             # Step 12: Feature selection
+             logger.info("\n" + "="*80)
+             logger.info("STEP 12: FEATURE SELECTION")
+             logger.info("="*80)
+ 
+             if len(train_data_scaled.columns) > 5:
+                 # Select features on training data
+                 train_data_selected = self.feature_selector.select(
+                     train_data_scaled,
+                     method=self.config.feature_selection_method,
+                     n_features=min(self.config.max_features, len(train_data_scaled.columns) - 1)
+                 )
+ 
+                 # Save selected features
+                 selected_features = self.feature_selector.selected_features
+ 
+                 # Apply same selection to validation and test data
+                 features_to_keep = selected_features + [self.config.target_column]
+                 features_to_keep = [f for f in features_to_keep if f in val_data_scaled.columns]
+ 
+                 if len(features_to_keep) > 1:
+                     self.train_data = train_data_scaled[features_to_keep].copy()
+                     self.val_data = val_data_scaled[features_to_keep].copy()
+                     self.test_data = test_data_scaled[features_to_keep].copy()
+                 else:
+                     logger.warning("Failed to select features, using all")
+ 
+                 self.results['feature_selection'] = self.feature_selector.get_report()
+             else:
+                 logger.info("Skipped: insufficient features for selection")
+                 self.results['feature_selection'] = {'skipped': 'insufficient features'}
+ 
+             # Step 13: Final validation
+             logger.info("\n" + "="*80)
+             logger.info("STEP 13: FINAL VALIDATION")
+             logger.info("="*80)
+ 
+             # Combine all data for final validation
+             all_processed_data = pd.concat([self.train_data, self.val_data, self.test_data])
+ 
+             final_validation = self.data_validator.validate(
+                 all_processed_data, stage='final', detailed=True
+             )
+             self.results['final_validation'] = final_validation
+ 
+             self.processed_data = all_processed_data
+             self.is_fitted = True
+ 
+             # Step 14: Additional multicollinearity cleaning
+             logger.info("\n" + "="*80)
+             logger.info("STEP 14: ADDITIONAL MULTICOLLINEARITY CLEANING")
+             logger.info("="*80)
+ 
+             # Remove temporal features with extreme VIF
+             self.processed_data = self._remove_extreme_vif_features(self.processed_data)
+             self.train_data = self.train_data[self.processed_data.columns]
+             self.val_data = self.val_data[self.processed_data.columns]
+             self.test_data = self.test_data[self.processed_data.columns]
+ 
+             # Step 15: Create visualisations and reports
+             logger.info("\n" + "="*80)
+             logger.info("STEP 15: CREATING REPORTS AND VISUALISATIONS")
+             logger.info("="*80)
+ 
+             if create_reports:
+                 self.create_all_reports()
+                 self.create_all_visualisations()
+ 
+             # Calculate execution time
+             execution_time = (datetime.now() - start_time).total_seconds()
+ 
+             # Save final results
+             self.results['pipeline_execution'] = {
+                 'start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'),
+                 'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+                 'execution_time_seconds': execution_time,
+                 'success': True,
+                 'stages_completed': 15
+             }
+ 
+             # Save configuration and results
+             self.save_pipeline_results()
+ 
+             logger.info("\n" + "="*80)
+             logger.info("ENHANCED PIPELINE SUCCESSFULLY COMPLETED!")
+             logger.info("="*80)
+             logger.info(f"Execution time: {execution_time:.2f} seconds")
+             logger.info(f"Initial data size: {self.results['data_loading']['shape']}")
+             logger.info(f"Final data size: {list(self.processed_data.shape)}")
+             logger.info(f"Data quality: {final_validation['overall_score']}/100")
+             logger.info(f"Status: {final_validation['status']}")
+             logger.info(f"Training data: {len(self.train_data)} records")
+             logger.info(f"Features in final set: {len(self.train_data.columns)}")
+ 
+             return self.processed_data
+ 
+         except Exception as e:
+             logger.error(f"✗ Pipeline error: {e}")
+             logger.error(traceback.format_exc())
+ 
+             self.results['pipeline_execution'] = {
+                 'success': False,
+                 'error': str(e),
+                 'traceback': traceback.format_exc()
+             }
+ 
+             # Save partial results
+             self.save_pipeline_results()
+ 
+             return None
+ 
+     def _remove_extreme_vif_features(self, data: pd.DataFrame) -> pd.DataFrame:
+         """Remove features with extreme VIF"""
+         data_clean = data.copy()
+ 
+         # Identify features with extreme VIF for removal
+         extreme_vif_features = [
+             'year', 'day', 'dayofyear', 'days_from_start',
+             'raskhodvoda_zscore'  # Usually has extreme VIF
+         ]
+ 
+         # Remove only those present in data
+         features_to_remove = [f for f in extreme_vif_features if f in data_clean.columns]
+ 
+         if features_to_remove:
+             logger.info(f"Removing features with extreme VIF: {features_to_remove}")
+             data_clean = data_clean.drop(columns=features_to_remove)
+ 
+         return data_clean
+ 
+     def create_all_reports(self) -> None:
+         """Create all reports"""
+         logger.info("Creating reports...")
+ 
+         # 1. Save validation results
+         for stage in ['raw', 'processed', 'final']:
+             if stage in self.data_validator.validation_results:
+                 self.data_validator.save_report(stage)
+ 
+         # 2. Save plots information
+         self.visualisation_manager.save_plots_info()
+ 
+         # 3. Create summary report
+         self.create_summary_report()
+ 
+         logger.info("✓ All reports created")
+ 
+     def create_all_visualisations(self) -> None:
+         """Create all visualisations"""
+         logger.info("Creating visualisations...")
+ 
+         if self.processed_data is not None:
+             # 1. Summary dashboard
+             preprocessing_stages = {
+                 'Loading': self.results['data_loading']['shape'][1] if 'data_loading' in self.results else 0,
+                 'After cleaning': len(self.processed_data.columns),
+                 'Features created': self.feature_engineer.feature_info.get('features_created', 0),
+                 'Features selected': len(self.feature_selector.selected_features) if hasattr(self.feature_selector, 'selected_features') else 0
+             }
+ 
+             self.visualisation_manager.create_summary_dashboard(
+                 self.processed_data,
+                 preprocessing_stages
+             )
+ 
+         logger.info("✓ All visualisations created")
+ 
+     def create_summary_report(self) -> None:
+         """Create summary report"""
+         report = {
+             'pipeline_summary': {
+                 'config': self.config.to_dict(),
+                 'execution': self.results.get('pipeline_execution', {}),
+                 'data_statistics': {
+                     'initial_shape': self.results.get('data_loading', {}).get('shape', []),
+                     'final_shape': list(self.processed_data.shape) if self.processed_data is not None else [],
+                     'target_column': self.config.target_column,
+                     'features_created': self.feature_engineer.feature_info.get('features_created', 0),
+                     'features_selected': len(self.feature_selector.selected_features) if hasattr(self.feature_selector, 'selected_features') else 0
+                 }
+             },
+             'validation_summary': {},
+             'quality_metrics': {}
+         }
+ 
+         # Add validation results
+         for stage in ['raw', 'processed', 'final']:
+             if stage in self.data_validator.validation_results:
+                 stage_results = self.data_validator.validation_results[stage]
+                 report['validation_summary'][stage] = {
+                     'status': stage_results.get('status'),
+                     'score': stage_results.get('overall_score'),
+                     'issues_count': sum(len(issues) for issues in stage_results.get('issues', {}).values()),
+                     'checks_passed': sum(1 for check in stage_results.get('basic_checks', {}).values()
+                                          if check.get('passed', False))
+                 }
+ 
+         # Save report
+         report_path = f'{self.config.results_dir}/reports/pipeline_summary.json'
+ 
+         with open(report_path, 'w', encoding='utf-8') as f:
+             json.dump(report, f, indent=4, ensure_ascii=False)
+ 
+         logger.info(f"✓ Summary report saved: {report_path}")
+ 
+     def save_pipeline_results(self) -> None:
+         """Save all pipeline results"""
+         # Save configuration
+         self.config.save()
+ 
+         # Save data
+         if self.processed_data is not None:
+             # Save processed data
+             data_path = f'{self.config.results_dir}/processed_data/processed_data.csv'
+             self.processed_data.to_csv(data_path)
+             logger.info(f"✓ Processed data saved: {data_path}")
+ 
+         # Save split data
+         if self.train_data is not None:
+             self.train_data.to_csv(f'{self.config.results_dir}/processed_data/train_data.csv')
+             self.val_data.to_csv(f'{self.config.results_dir}/processed_data/val_data.csv')
+             self.test_data.to_csv(f'{self.config.results_dir}/processed_data/test_data.csv')
+ 
+     def get_final_data_for_modelling(self) -> Dict[str, Any]:
+         """Prepare data for modelling"""
+         if not self.is_fitted:
+             logger.warning("Pipeline not executed, data not ready")
+             return {}
+ 
+         return {
+             'X_train': self.train_data.drop(columns=[self.config.target_column]),
+             'y_train': self.train_data[self.config.target_column],
+             'X_val': self.val_data.drop(columns=[self.config.target_column]),
+             'y_val': self.val_data[self.config.target_column],
+             'X_test': self.test_data.drop(columns=[self.config.target_column]),
+             'y_test': self.test_data[self.config.target_column],
+             'feature_names': self.train_data.drop(columns=[self.config.target_column]).columns.tolist(),
+             'scaler': self.data_scaler,
+             'feature_selector': self.feature_selector,
+             'results': self.results
+         }
+ 
+ 
+ # ============================================
+ # QUICK LAUNCH FUNCTION
+ # ============================================
+ def run_enhanced_preprocessing(
+     config_path: Optional[str] = None,
+     data_path: Optional[str] = None,
+     use_synthetic: bool = False,
+     save_results: bool = True
+ ) -> EnhancedDataPreprocessingPipeline:
+     """
+     Quick launch function for enhanced pipeline
+ 
+     Parameters:
+     -----------
+     config_path : str, optional
+         Path to configuration file
+     data_path : str, optional
+         Path to data
+     use_synthetic : bool
+         Use synthetic data
+     save_results : bool
+         Save results
+ 
+     Returns:
+     --------
+     EnhancedDataPreprocessingPipeline
+         Pipeline object with results
+     """
+     # Load or create configuration
+     if config_path and os.path.exists(config_path):
+         config = Config.load(config_path)
+         logger.info(f"Configuration loaded from {config_path}")
+     else:
+         config = Config()
+         logger.info("Using default configuration")
+ 
+     # Update data path if specified
+     if data_path:
+         config.data_path = data_path
+ 
+     # Create and run pipeline
+     pipeline = EnhancedDataPreprocessingPipeline(config)
+ 
+     pipeline.run_full_pipeline(
+         data_path=data_path,
+         use_synthetic=use_synthetic,
+         save_intermediate=save_results,
+         create_reports=save_results
+     )
+ 
+     return pipeline
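Taken together, `get_final_data_for_modelling` hands back ready-to-use train/validation/test matrices plus the fitted scaler and selector. A launch sketch (illustrative, not from the commit), assuming the default `Config` creates its results directories and that scikit-learn from requirements.txt is installed:

```python
from sklearn.linear_model import LinearRegression

from pipeline.main_pipeline import run_enhanced_preprocessing

# Synthetic mode exercises every pipeline stage without needing a real CSV
pipeline = run_enhanced_preprocessing(use_synthetic=True)

if pipeline.is_fitted:
    bundle = pipeline.get_final_data_for_modelling()
    model = LinearRegression().fit(bundle["X_train"], bundle["y_train"])
    print("validation R^2:", model.score(bundle["X_val"], bundle["y_val"]))
```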
requirements.txt CHANGED
@@ -1,9 +1,100 @@
- streamlit
- pandas
- numpy
- plotly
- joblib
- huggingface_hub
- scikit-learn
- openpyxl
- xgboost
+ absl-py==2.3.1
+ altair==6.0.0
+ anyio==4.11.0
+ astunparse==1.6.3
+ attrs==25.4.0
+ blinker==1.9.0
+ cachetools==6.2.4
+ certifi==2025.11.12
+ charset-normalizer==3.4.4
+ click==8.3.1
+ colorama==0.4.6
+ contourpy==1.3.2
+ cycler==0.12.1
+ et_xmlfile==2.0.0
+ filelock==3.20.1
+ flatbuffers==25.12.19
+ fonttools==4.61.1
+ fsspec==2025.12.0
+ gast==0.7.0
+ gensim==4.4.0
+ gitdb==4.0.12
+ GitPython==3.1.46
+ google-pasta==0.2.0
+ grpcio==1.76.0
+ h5py==3.15.1
+ h11==0.16.0
+ hf-xet==1.2.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface_hub==1.1.2
+ idna==3.11
+ Jinja2==3.1.6
+ joblib==1.5.3
+ jsonschema==4.25.1
+ jsonschema-specifications==2025.9.1
+ keras==3.13.0
+ kiwisolver==1.4.9
+ libclang==18.1.1
+ Markdown==3.10
+ markdown-it-py==4.0.0
+ MarkupSafe==3.0.3
+ matplotlib==3.10.8
+ mdurl==0.1.2
+ ml_dtypes==0.5.4
+ mpmath==1.3.0
+ namex==0.1.0
+ narwhals==2.14.0
+ networkx==3.6.1
+ numpy==2.4.0
+ openpyxl==3.1.5
+ opt_einsum==3.4.0
+ optree==0.18.0
+ packaging==25.0
+ pandas==2.3.3
+ patsy==1.0.2
+ pillow==12.0.0
+ plotly==6.5.0
+ protobuf==6.33.2
+ pyarrow==22.0.0
+ pydeck==0.9.1
+ Pygments==2.19.2
+ pyparsing==3.3.1
+ pyperclip==1.11.0
+ python-dateutil==2.9.0.post0
+ pytz==2025.2
+ PyYAML==6.0.3
+ referencing==0.37.0
+ requests==2.32.5
+ rich==14.2.0
+ rpds-py==0.30.0
+ scikit-learn==1.8.0
+ scipy==1.16.3
+ seaborn==0.13.2
+ setuptools==80.9.0
+ shellingham==1.5.4
+ six==1.17.0
+ smart_open==7.5.0
+ smmap==5.0.2
+ sniffio==1.3.1
+ statsmodels==0.14.6
+ streamlit==1.52.2
+ sympy==1.14.0
+ tenacity==9.1.2
+ tensorboard==2.20.0
+ tensorboard-data-server==0.7.2
+ tensorflow==2.20.0
+ termcolor==3.3.0
+ threadpoolctl==3.6.0
+ toml==0.10.2
+ torch==2.9.1
+ tornado==6.5.4
+ tqdm==4.67.1
+ typer-slim==0.20.0
+ typing_extensions==4.15.0
+ tzdata==2025.3
+ urllib3==2.6.2
+ watchdog==6.0.0
+ Werkzeug==3.1.4
+ wheel==0.45.1
+ wrapt==2.0.1
run_pipeline.py ADDED
@@ -0,0 +1,62 @@
+ # ============================================
+ # RUN
+ # ============================================
+ from config.config import Config
+ from pipeline.main_pipeline import EnhancedDataPreprocessingPipeline
+ import pandas as pd
+ 
+ if __name__ == "__main__":
+     """
+     Pipeline execution
+     """
+ 
+     # Configuration with reasonable parameters
+     config = Config(
+         data_path='temp_data.csv',
+         results_dir='enhanced_preprocessing_results',
+         target_column='raskhodvoda',
+         start_year=1970,
+         end_year=1990,
+         max_lags=5,
+         seasonal_period=365,
+         rolling_windows=[7, 30, 90],
+         expanding_windows=[30, 90],
+         test_size=0.2,
+         validation_size=0.1,
+         scaling_method='robust',
+         feature_selection_method='correlation',
+         max_features=20,
+         missing_threshold=0.3,
+         outlier_method='iqr',
+         enable_validation=True
+     )
+ 
+     # Run enhanced pipeline
+     pipeline = EnhancedDataPreprocessingPipeline(config)
+     processed_data = pipeline.run_full_pipeline(
+         use_synthetic=False,
+         save_intermediate=True,
+         create_reports=True
+     )
+ 
+     if processed_data is not None:
+         print("\n" + "="*80)
+         print("ENHANCED PIPELINE SUCCESSFULLY COMPLETED!")
+         print("="*80)
+         print(f"Final data size: {processed_data.shape}")
+         print(f"Columns: {list(processed_data.columns)}")
+ 
+         # Get modelling data
+         modelling_data = pipeline.get_final_data_for_modelling()
+ 
+         if modelling_data:
+             print("\nModelling data ready:")
+             print(f"  X_train: {modelling_data['X_train'].shape}")
+             print(f"  X_val: {modelling_data['X_val'].shape}")
+             print(f"  X_test: {modelling_data['X_test'].shape}")
+             print(f"  Features: {len(modelling_data['feature_names'])}")
+ 
+         # Save final data (forward slashes keep the path portable)
+         processed_data.to_csv(
+             'enhanced_preprocessing_results/processed_data/enhanced_final_processed_data.csv',
+             index=True if isinstance(processed_data.index, pd.DatetimeIndex) else False
+         )
+         print("\n✓ Final data saved to 'enhanced_final_processed_data.csv'")
scaling/__init__.py ADDED
File without changes
scaling/data_scaler.py ADDED
@@ -0,0 +1,634 @@
+ # ============================================
+ # CLASS 10: DATA SCALING
+ # ============================================
+ import logging
+ from typing import Dict, List, Optional, Tuple
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from config.config import Config
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class DataScaler:
+     """Class for data scaling and normalisation"""
+ 
+     def __init__(self, config: Config):
+         """
+         Initialise scaler
+ 
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.scalers = {}
+         self.scaling_info = {}
+         self.transforms_applied = {}
+ 
+     def fit_transform(
+         self,
+         data: pd.DataFrame,
+         method: str = None,
+         columns: List[str] = None,
+         target_col: Optional[str] = None,
+         fit_on_train: bool = True,
+         **kwargs
+     ) -> pd.DataFrame:
+         """
+         Scale data
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         method : str, optional
+             Scaling method. If None, uses the configuration value.
+         columns : List[str], optional
+             List of columns to scale. If None, uses all numeric columns.
+         target_col : str, optional
+             Target variable (not scaled by default)
+         fit_on_train : bool
+             Whether to save scaling parameters for applying to new data
+         **kwargs : dict
+             Additional parameters for the method
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Scaled data
+         """
+         logger.info("\n" + "="*80)
+         logger.info("DATA SCALING")
+         logger.info("="*80)
+ 
+         method = method or self.config.scaling_method
+         data_scaled = data.copy()
+ 
+         if columns is None:
+             # Select all numeric columns except the target
+             numeric_cols = data_scaled.select_dtypes(include=[np.number]).columns
+             if target_col and target_col in numeric_cols:
+                 columns = [col for col in numeric_cols if col != target_col]
+             else:
+                 columns = list(numeric_cols)
+ 
+         logger.info(f"Scaling method: {method}")
+         logger.info(f"Columns to process: {len(columns)}")
+ 
+         # Apply scaling
+         for col in columns:
+             if col in data_scaled.columns:
+                 try:
+                     # Check feature type
+                     series = data_scaled[col].dropna()
+ 
+                     # Special handling for different feature types
+                     if self._is_binary_feature(series):
+                         logger.debug(f"  {col}: binary feature, scaling not applied")
+                         scaler_info = {
+                             'method': 'none',
+                             'scaler_type': 'binary',
+                             'original_values': sorted(series.unique().tolist()),
+                             'note': 'binary feature, no scaling applied'
+                         }
+                         self.scaling_info[col] = scaler_info
+ 
+                         if fit_on_train:
+                             self.scalers[col] = scaler_info
+ 
+                     elif self._is_categorical_feature(series):
+                         logger.debug(f"  {col}: categorical feature, using min-max")
+                         scaled_series, scaler_info = self._apply_scaling(
+                             data_scaled[col], 'minmax', fit_on_train, **kwargs
+                         )
+                         data_scaled[col] = scaled_series
+                         self.scaling_info[col] = scaler_info
+ 
+                         if fit_on_train:
+                             if scaler_info.get('scaler_type') == 'sklearn':
+                                 self.scalers[col] = scaler_info['scaler_object']
+                             else:
+                                 self.scalers[col] = scaler_info
+ 
+                     else:
+                         # Regular scaling for continuous features
+                         scaled_series, scaler_info = self._apply_scaling(
+                             data_scaled[col], method, fit_on_train, **kwargs
+                         )
+                         data_scaled[col] = scaled_series
+                         self.scaling_info[col] = scaler_info
+ 
+                         if fit_on_train:
+                             if scaler_info.get('scaler_type') == 'sklearn':
+                                 self.scalers[col] = scaler_info['scaler_object']
+                             else:
+                                 self.scalers[col] = scaler_info
+ 
+                 except Exception as e:
+                     logger.warning(f"Error processing column {col}: {e}")
+                     # Save error information
+                     self.scaling_info[col] = {
+                         'method': 'error',
+                         'error': str(e),
+                         'scaler_type': 'none'
+                     }
+ 
+         logger.info(f"✓ Data processed using {method} method")
+ 
+         # Visualisation of results
+         if self.config.save_plots and columns:
+             self._plot_scaling_results(data, data_scaled, columns, method)
+ 
+         return data_scaled
+ 
+     def _is_binary_feature(self, series: pd.Series) -> bool:
+         """Check if feature is binary"""
+         unique_values = series.dropna().unique()
+         return len(unique_values) == 2 and set(unique_values).issubset({0, 1})
+ 
+     def _is_categorical_feature(self, series: pd.Series, max_categories: int = 10) -> bool:
+         """Check if feature is categorical"""
+         unique_values = series.dropna().unique()
+         return len(unique_values) <= max_categories and series.dtype in ['int64', 'float64']
+ 
+     def _apply_scaling(
+         self,
+         series: pd.Series,
+         method: str,
+         fit_on_train: bool,
+         **kwargs
+     ) -> Tuple[pd.Series, Dict]:
+         """Apply a specific scaling method"""
+         series_clean = series.dropna()
+ 
+         if len(series_clean) == 0:
+             return series, {
+                 'method': 'none',
+                 'scaler_type': 'none',
+                 'error': 'all values are NaN'
+             }
+ 
+         scaler_info = {
+             'method': method,
+             'scaler_type': 'simple',
+             'original_mean': float(series_clean.mean()),
+             'original_std': float(series_clean.std()),
+             'original_min': float(series_clean.min()),
+             'original_max': float(series_clean.max()),
+             'scaler': None,
+             'scaler_object': None
+         }
+ 
+         try:
+             if method == 'standard':
+                 # Standardisation (z-score normalisation)
+                 mean = series_clean.mean()
+                 std = series_clean.std()
+ 
+                 if std > 0:
+                     series_scaled = (series - mean) / std
+                     scaler_info['scaler'] = {'mean': float(mean), 'std': float(std)}
+                 else:
+                     series_scaled = series - mean  # If std = 0, just centre
+                     scaler_info['scaler'] = {'mean': float(mean), 'std': 0}
+ 
+             elif method == 'minmax':
+                 # Min-max normalisation
+                 min_val = series_clean.min()
+                 max_val = series_clean.max()
+ 
+                 if max_val > min_val:
+                     series_scaled = (series - min_val) / (max_val - min_val)
+                     scaler_info['scaler'] = {'min': float(min_val), 'max': float(max_val)}
+                 else:
+                     series_scaled = series - min_val  # If all values are equal
+                     scaler_info['scaler'] = {'min': float(min_val), 'max': float(min_val)}
+ 
+             elif method == 'robust':
+                 # Robust scaling (outlier resistant)
+                 # Check there are enough values for quartile calculation
+                 if len(series_clean) >= 4:
+                     median = series_clean.median()
+                     q1 = series_clean.quantile(0.25)
+                     q3 = series_clean.quantile(0.75)
+                     iqr = q3 - q1
+ 
+                     if iqr > 0:
+                         series_scaled = (series - median) / iqr
+                         scaler_info['scaler'] = {
+                             'median': float(median),
+                             'q1': float(q1),
+                             'q3': float(q3),
+                             'iqr': float(iqr)
+                         }
+                     else:
+                         # If IQR = 0, use the standard deviation
+                         std = series_clean.std()
+                         if std > 0:
+                             series_scaled = (series - median) / std
+                             scaler_info['scaler'] = {'median': float(median), 'std': float(std)}
+                         else:
+                             series_scaled = series - median
+                             scaler_info['scaler'] = {'median': float(median), 'iqr': 0}
+                 else:
+                     # If there is insufficient data, fall back to standardisation
+                     mean = series_clean.mean()
+                     std = series_clean.std()
+                     if std > 0:
+                         series_scaled = (series - mean) / std
+                         scaler_info['scaler'] = {'mean': float(mean), 'std': float(std)}
+                         scaler_info['method'] = 'standard'  # Record the method actually used
+                     else:
+                         series_scaled = series - mean
+                         scaler_info['scaler'] = {'mean': float(mean), 'std': 0}
+                         scaler_info['method'] = 'standard'
+ 
+             elif method == 'log':
+                 # Logarithmic transformation
+                 min_val = series_clean.min()
+ 
+                 if min_val <= 0:
+                     shift = abs(min_val) + 1
+                     series_scaled = np.log(series + shift)
+                     scaler_info['scaler'] = {'shift': float(shift)}
+                 else:
+                     series_scaled = np.log(series)
+                     scaler_info['scaler'] = {'shift': 0}
+ 
+             elif method == 'boxcox':
+                 # Box-Cox transformation
+                 try:
+                     from scipy.stats import boxcox
+ 
+                     min_val = series_clean.min()
+                     if min_val <= 0:
+                         shift = abs(min_val) + 1
+                         series_to_transform = series + shift
+                     else:
+                         shift = 0
+                         series_to_transform = series
+ 
+                     transformed, lambda_val = boxcox(series_to_transform.dropna())
+ 
+                     # Write transformed values back to the non-NaN positions
+                     series_scaled = series.copy()
+                     valid_mask = series_to_transform.notna()
+                     series_scaled[valid_mask] = transformed
+ 
+                     scaler_info['scaler'] = {
+                         'lambda': float(lambda_val),
+                         'shift': float(shift)
+                     }
+ 
+                 except Exception as e:
+                     logger.warning(f"Box-Cox transformation failed for {series.name}: {e}")
+                     # Return the original series and record the failure
+                     series_scaled = series
+                     scaler_info['method'] = 'none'
+                     scaler_info['scaler_type'] = 'none'
+                     scaler_info['error'] = str(e)
+ 
+             elif method == 'quantile':
+                 # Quantile transformation (rank-based)
+                 try:
+                     from sklearn.preprocessing import QuantileTransformer
+ 
+                     qt = QuantileTransformer(
+                         n_quantiles=kwargs.get('n_quantiles', min(100, len(series_clean))),
+                         output_distribution=kwargs.get('output_distribution', 'normal'),
+                         random_state=kwargs.get('random_state', 42)
+                     )
+ 
+                     series_reshaped = series.values.reshape(-1, 1)
+                     series_scaled_values = qt.fit_transform(series_reshaped)
+                     series_scaled = pd.Series(series_scaled_values.flatten(), index=series.index)
+ 
+                     scaler_info['scaler_type'] = 'sklearn'
+                     scaler_info['scaler_object'] = qt
+ 
+                 except Exception as e:
+                     logger.warning(f"Quantile transform failed for {series.name}: {e}")
+                     series_scaled = series
+                     scaler_info['method'] = 'none'
+                     scaler_info['scaler_type'] = 'none'
+                     scaler_info['error'] = str(e)
+ 
+             elif method == 'power':
+                 # Power transform (Yeo-Johnson)
+                 try:
+                     from sklearn.preprocessing import PowerTransformer
+ 
+                     pt = PowerTransformer(method='yeo-johnson', standardize=True)
+ 
+                     series_reshaped = series.values.reshape(-1, 1)
+                     series_scaled_values = pt.fit_transform(series_reshaped)
+                     series_scaled = pd.Series(series_scaled_values.flatten(), index=series.index)
+ 
+                     scaler_info['scaler_type'] = 'sklearn'
+                     scaler_info['scaler_object'] = pt
+ 
+                 except Exception as e:
+                     logger.warning(f"Power transform failed for {series.name}: {e}")
+                     series_scaled = series
+                     scaler_info['method'] = 'none'
+                     scaler_info['scaler_type'] = 'none'
+                     scaler_info['error'] = str(e)
+ 
+             elif method == 'none':
+                 # No scaling
+                 series_scaled = series
+                 scaler_info['method'] = 'none'
+                 scaler_info['scaler_type'] = 'none'
+ 
+             else:
+                 logger.warning(f"Unknown scaling method: {method}, using standardisation")
+                 return self._apply_scaling(series, 'standard', fit_on_train, **kwargs)
+ 
+             # Add post-scaling statistics
+             scaled_clean = series_scaled.dropna()
+             if len(scaled_clean) > 0:
+                 scaler_info.update({
+                     'scaled_mean': float(scaled_clean.mean()),
+                     'scaled_std': float(scaled_clean.std()),
+                     'scaled_min': float(scaled_clean.min()),
+                     'scaled_max': float(scaled_clean.max()),
+                     'skewness_before': float(series_clean.skew()),
+                     'skewness_after': float(scaled_clean.skew()),
+                     'kurtosis_before': float(series_clean.kurtosis()),
+                     'kurtosis_after': float(scaled_clean.kurtosis())
+                 })
+ 
+             return series_scaled, scaler_info
+ 
+         except Exception as e:
+             logger.warning(f"Error applying method {method} for {series.name}: {e}")
+             return series, {
+                 'method': 'error',
+                 'scaler_type': 'none',
+                 'error': str(e)
+             }
+ 
+     def transform(
+         self,
+         data: pd.DataFrame,
+         columns: List[str] = None
+     ) -> pd.DataFrame:
+         """
+         Apply saved scaling to new data
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             New data
+         columns : List[str], optional
+             List of columns to transform
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Transformed data
+         """
+         if not self.scalers:
+             logger.warning("Scalers not trained, use fit_transform first")
+             return data
+ 
+         data_transformed = data.copy()
+ 
+         if columns is None:
+             columns = [col for col in self.scalers.keys() if col in data_transformed.columns]
+ 
+         for col in columns:
+             if col in data_transformed.columns and col in self.scalers:
+                 try:
+                     scaler_info = self.scaling_info.get(col, {})
+                     scaler_data = self.scalers[col]
+                     method = scaler_info.get('method', 'unknown')
+ 
+                     # Binary features are passed through unchanged
+                     if method == 'none' and scaler_info.get('scaler_type') == 'binary':
+                         continue
+ 
+                     # Skip columns that failed during fitting
+                     if method == 'error':
+                         continue
+ 
+                     if isinstance(scaler_data, dict) and 'scaler' in scaler_data:
+                         scaler_params = scaler_data['scaler']
+ 
+                         if method == 'standard':
+                             mean = scaler_params.get('mean', 0)
+                             std = scaler_params.get('std', 1)
+                             if std > 0:
+                                 data_transformed[col] = (data_transformed[col] - mean) / std
+ 
+                         elif method == 'minmax':
+                             min_val = scaler_params.get('min', 0)
+                             max_val = scaler_params.get('max', 1)
+                             if max_val > min_val:
+                                 data_transformed[col] = (data_transformed[col] - min_val) / (max_val - min_val)
+ 
+                         elif method == 'robust':
+                             median = scaler_params.get('median', 0)
+                             iqr = scaler_params.get('iqr', 1)
+                             if iqr > 0:
+                                 data_transformed[col] = (data_transformed[col] - median) / iqr
+                             else:
+                                 std = scaler_params.get('std', 1)
+                                 if std > 0:
+                                     data_transformed[col] = (data_transformed[col] - median) / std
+ 
+                     elif hasattr(scaler_data, 'transform'):
+                         # For sklearn objects
+                         from sklearn.base import BaseEstimator
+                         if isinstance(scaler_data, BaseEstimator):
+                             try:
+                                 transformed = scaler_data.transform(
+                                     data_transformed[[col]].values.reshape(-1, 1)
+                                 ).flatten()
+                                 data_transformed[col] = transformed
+                             except Exception as e:
+                                 logger.warning(f"Error in sklearn transformation for {col}: {e}")
+ 
+                 except Exception as e:
+                     logger.warning(f"Error transforming column {col}: {e}")
+ 
+         return data_transformed
+ 
+     def inverse_transform(
+         self,
+         data: pd.DataFrame,
+         columns: List[str] = None
+     ) -> pd.DataFrame:
+         """
+         Inverse transform scaled data
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Scaled data
+         columns : List[str], optional
+             List of columns for the inverse transform
+ 
+         Returns:
+         --------
+         pd.DataFrame
+             Data in the original scale
+         """
+         if not self.scalers:
+             logger.warning("Scalers not trained")
+             return data
+ 
+         data_inverse = data.copy()
+ 
+         if columns is None:
+             columns = [col for col in self.scalers.keys() if col in data_inverse.columns]
+ 
+         for col in columns:
+             if col in data_inverse.columns and col in self.scalers:
+                 try:
+                     scaler_info = self.scaling_info.get(col, {})
+                     scaler_data = self.scalers[col]
+                     method = scaler_info.get('method', 'unknown')
+ 
+                     # Binary and categorical features are passed through unchanged
+                     if method in ['none', 'error']:
+                         continue
+ 
+                     if isinstance(scaler_data, dict) and 'scaler' in scaler_data:
+                         scaler_params = scaler_data['scaler']
+ 
+                         if method == 'standard':
+                             mean = scaler_params.get('mean', 0)
+                             std = scaler_params.get('std', 1)
+                             data_inverse[col] = data_inverse[col] * std + mean
+ 
+                         elif method == 'minmax':
+                             min_val = scaler_params.get('min', 0)
+                             max_val = scaler_params.get('max', 1)
+                             if max_val > min_val:
+                                 data_inverse[col] = data_inverse[col] * (max_val - min_val) + min_val
+ 
+                         elif method == 'robust':
+                             median = scaler_params.get('median', 0)
+                             iqr = scaler_params.get('iqr', 1)
+                             if iqr > 0:
+                                 data_inverse[col] = data_inverse[col] * iqr + median
+                             else:
+                                 std = scaler_params.get('std', 1)
+                                 if std > 0:
+                                     data_inverse[col] = data_inverse[col] * std + median
+ 
+                     elif hasattr(scaler_data, 'inverse_transform'):
+                         # For sklearn objects
+                         from sklearn.base import BaseEstimator
+                         if isinstance(scaler_data, BaseEstimator):
+                             try:
+                                 inverse_transformed = scaler_data.inverse_transform(
+                                     data_inverse[[col]].values.reshape(-1, 1)
+                                 ).flatten()
+                                 data_inverse[col] = inverse_transformed
+                             except Exception as e:
+                                 logger.warning(f"Error in sklearn inverse transformation for {col}: {e}")
+ 
+                 except Exception as e:
+                     logger.warning(f"Error in inverse transformation for column {col}: {e}")
+ 
+         return data_inverse
+ 
+     def _plot_scaling_results(
+         self,
+         original_data: pd.DataFrame,
+         scaled_data: pd.DataFrame,
+         columns: List[str],
+         method: str
+     ) -> None:
+         """Visualise scaling results"""
+         # Limit the number of columns shown
+         cols_to_plot = [col for col in columns if col in original_data.columns and col in scaled_data.columns][:8]
+ 
+         if not cols_to_plot:
+             return
+ 
+         n_cols = 4
+         n_rows = (len(cols_to_plot) + n_cols - 1) // n_cols
+ 
+         # squeeze=False keeps the axes array 2-D even when there is a single row
+         fig, axes = plt.subplots(n_rows, n_cols * 2, figsize=(16, 4 * n_rows), squeeze=False)
+ 
+         for idx, col in enumerate(cols_to_plot):
+             row = idx // n_cols
+             col_idx = (idx % n_cols) * 2
+ 
+             # Distribution before scaling
+             axes[row, col_idx].hist(
+                 original_data[col].dropna(),
+                 bins=30,
+                 alpha=0.7,
+                 color='blue',
+                 density=True
+             )
+             axes[row, col_idx].set_title(f'{col} (before)', fontsize=10)
+             axes[row, col_idx].set_xlabel('Value')
+             axes[row, col_idx].set_ylabel('Density')
+             axes[row, col_idx].grid(True, alpha=0.3)
+ 
+             # Distribution after scaling
+             axes[row, col_idx + 1].hist(
+                 scaled_data[col].dropna(),
+                 bins=30,
+                 alpha=0.7,
+                 color='green',
+                 density=True
+             )
+             axes[row, col_idx + 1].set_title(f'{col} (after)', fontsize=10)
+             axes[row, col_idx + 1].set_xlabel('Scaled value')
+             axes[row, col_idx + 1].set_ylabel('Density')
+             axes[row, col_idx + 1].grid(True, alpha=0.3)
+ 
+         # Hide unused subplots
+         total_plots = n_rows * n_cols * 2
+         for idx in range(len(cols_to_plot) * 2, total_plots):
+             row = idx // (n_cols * 2)
+             col_idx = idx % (n_cols * 2)
+             axes[row, col_idx].set_visible(False)
+ 
+         plt.suptitle(f'Scaling results using {method} method', fontsize=14)
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/scaling_results.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+ 
+     def get_report(self) -> Dict:
+         """Get scaling report"""
+         summary = {
+             'total_columns': len(self.scaling_info),
+             'methods_used': {},
+             'binary_features': [],
+             'categorical_features': [],
+             'continuous_features': [],
+             'errors': []
+         }
+ 
+         for col, info in self.scaling_info.items():
+             method = info.get('method', 'unknown')
+             if method not in summary['methods_used']:
+                 summary['methods_used'][method] = 0
+             summary['methods_used'][method] += 1
+ 
+             if method == 'none' and info.get('scaler_type') == 'binary':
+                 summary['binary_features'].append(col)
+             elif method in ['minmax', 'standard', 'robust']:
+                 summary['continuous_features'].append(col)
+             elif method == 'error':
+                 summary['errors'].append({
+                     'column': col,
+                     'error': info.get('error', 'unknown')
+                 })
+ 
+         return {
+             'summary': summary,
+             'details': self.scaling_info
+         }
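A minimal usage sketch for the new DataScaler, with hypothetical column names and a hand-rolled DataFrame; it assumes Config accepts results_dir, scaling_method and save_plots keyword arguments, mirroring how run_pipeline.py constructs it and how the class reads them:

    import numpy as np
    import pandas as pd
    from config.config import Config
    from scaling.data_scaler import DataScaler

    config = Config(results_dir='demo_results', scaling_method='robust', save_plots=False)
    train = pd.DataFrame({'flow': np.random.rand(100) * 50,
                          'is_flood': np.random.randint(0, 2, 100)})
    test = pd.DataFrame({'flow': np.random.rand(20) * 50,
                         'is_flood': np.random.randint(0, 2, 20)})

    scaler = DataScaler(config)
    train_scaled = scaler.fit_transform(train)        # fits and stores per-column parameters
    test_scaled = scaler.transform(test)              # reuses the parameters fitted on train
    restored = scaler.inverse_transform(test_scaled)  # back to the original scale
    print(scaler.get_report()['summary'])

Here is_flood is detected as binary and passed through unchanged, while flow gets the robust median/IQR scaling; fitting on train and reusing the parameters on test is what prevents leakage of test statistics into the scaler.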
splitting/__init__.py ADDED
File without changes
splitting/data_splitter.py ADDED
@@ -0,0 +1,403 @@
+ # ============================================
+ # CLASS 9: DATA SPLITTING
+ # ============================================
+ import logging
+ from datetime import datetime
+ from typing import Dict, Optional, Tuple
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from config.config import Config
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class DataSplitter:
+     """Class for splitting data into train, validation and test sets"""
+ 
+     def __init__(self, config: Config):
+         """
+         Initialise data splitter
+ 
+         Parameters:
+         -----------
+         config : Config
+             Experiment configuration
+         """
+         self.config = config
+         self.split_info = {}
+         self.split_indices = {}
+         self.split_strategy = None
+ 
+     def split(
+         self,
+         data: pd.DataFrame,
+         test_size: Optional[float] = None,
+         validation_size: Optional[float] = None,
+         method: str = None,
+         random_state: int = 42,
+         **kwargs
+     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+         """
+         Split data into train, validation and test sets
+ 
+         Parameters:
+         -----------
+         data : pd.DataFrame
+             Input data
+         test_size : float, optional
+             Test set size. If None, uses the configuration value.
+         validation_size : float, optional
+             Validation set size. If None, uses the configuration value.
+         method : str, optional
+             Splitting method: 'time', 'random', 'expanding_window', 'sliding_window'
+         random_state : int
+             Seed for reproducibility
+         **kwargs : dict
+             Additional parameters for the method
+ 
+         Returns:
+         --------
+         Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
+             Train, validation and test data
+         """
+         logger.info("\n" + "="*80)
+         logger.info("DATA SPLITTING")
+         logger.info("="*80)
+ 
+         test_size = test_size or self.config.test_size
+         validation_size = validation_size or self.config.validation_size
+         method = method or self.config.split_method
+ 
+         n = len(data)
+ 
+         logger.info(f"Total data: {n} records")
+         logger.info(f"Splitting method: {method}")
+         logger.info(f"Sizes: train={1-test_size-validation_size:.1%}, val={validation_size:.1%}, test={test_size:.1%}")
+ 
+         if method == 'time':
+             train_data, val_data, test_data = self._time_based_split(
+                 data, test_size, validation_size
+             )
+         elif method == 'random':
+             train_data, val_data, test_data = self._random_split(
+                 data, test_size, validation_size, random_state
+             )
+         elif method == 'expanding_window':
+             train_data, val_data, test_data = self._expanding_window_split(
+                 data, test_size, validation_size, **kwargs
+             )
+         elif method == 'sliding_window':
+             train_data, val_data, test_data = self._sliding_window_split(
+                 data, **kwargs
+             )
+         else:
+             logger.warning(f"Method {method} not supported, using time-based split")
+             train_data, val_data, test_data = self._time_based_split(
+                 data, test_size, validation_size
+             )
+ 
+         # Save splitting information
+         self._save_split_info(data, train_data, val_data, test_data, method)
+ 
+         # Output information
+         self._log_split_summary(train_data, val_data, test_data)
+ 
+         # Visualisation of the split
+         if self.config.save_plots:
+             self._plot_data_split(data, train_data, val_data, test_data)
+ 
+         return train_data, val_data, test_data
+ 
+     def _time_based_split(
+         self,
+         data: pd.DataFrame,
+         test_size: float,
+         validation_size: float
+     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+         """Time-based splitting preserving temporal order"""
+         n = len(data)
+ 
+         # Calculate set sizes
+         test_size_int = int(n * test_size)
+         val_size_int = int(n * validation_size)
+         train_size_int = n - test_size_int - val_size_int
+ 
+         # Split data
+         train_data = data.iloc[:train_size_int].copy()
+         val_data = data.iloc[train_size_int:train_size_int + val_size_int].copy()
+         test_data = data.iloc[train_size_int + val_size_int:].copy()
+ 
+         self.split_strategy = 'time_based'
+ 
+         return train_data, val_data, test_data
+ 
+     def _random_split(
+         self,
+         data: pd.DataFrame,
+         test_size: float,
+         validation_size: float,
+         random_state: int
+     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+         """Random data splitting"""
+         from sklearn.model_selection import train_test_split
+ 
+         # First split into train+val and test
+         train_val_data, test_data = train_test_split(
+             data,
+             test_size=test_size,
+             random_state=random_state,
+             shuffle=True
+         )
+ 
+         # Then split train+val into train and val
+         val_relative_size = validation_size / (1 - test_size)
+         train_data, val_data = train_test_split(
+             train_val_data,
+             test_size=val_relative_size,
+             random_state=random_state,
+             shuffle=True
+         )
+ 
+         self.split_strategy = 'random'
+ 
+         return train_data, val_data, test_data
+ 
+     def _expanding_window_split(
+         self,
+         data: pd.DataFrame,
+         test_size: float,
+         validation_size: float,
+         **kwargs
+     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+         """Expanding window split"""
+         n = len(data)
+ 
+         # Minimum initial window size
+         initial_window = kwargs.get('initial_window', max(100, int(n * 0.1)))
+ 
+         # Final set sizes
+         test_size_int = int(n * test_size)
+         val_size_int = int(n * validation_size)
+ 
+         # Determine boundaries
+         test_start = n - test_size_int
+         val_start = test_start - val_size_int
+ 
+         # For an expanding window, use all data up to val_start for training
+         train_data = data.iloc[:val_start].copy()
+         val_data = data.iloc[val_start:test_start].copy()
+         test_data = data.iloc[test_start:].copy()
+ 
+         self.split_strategy = 'expanding_window'
+ 
+         return train_data, val_data, test_data
+ 
+     def _sliding_window_split(
+         self,
+         data: pd.DataFrame,
+         **kwargs
+     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+         """Sliding window split (for multiple train-val-test pairs)"""
+         window_size = kwargs.get('window_size', len(data) // 3)
+         step = kwargs.get('step', window_size // 2)
+ 
+         # For simplicity return a single split;
+         # in real scenarios this could return a list of splits
+         n = len(data)
+ 
+         train_end = n - window_size
+         val_end = train_end + window_size // 3
+         test_end = n
+ 
+         train_data = data.iloc[:train_end].copy()
+         val_data = data.iloc[train_end:val_end].copy()
+         test_data = data.iloc[val_end:].copy()
+ 
+         self.split_strategy = 'sliding_window'
+ 
+         return train_data, val_data, test_data
+ 
+     def _save_split_info(
+         self,
+         full_data: pd.DataFrame,
+         train_data: pd.DataFrame,
+         val_data: pd.DataFrame,
+         test_data: pd.DataFrame,
+         method: str
+     ) -> None:
+         """Save splitting information"""
+         n = len(full_data)
+ 
+         self.split_info = {
+             'method': method,
+             'strategy': self.split_strategy,
+             'train_size': len(train_data),
+             'val_size': len(val_data),
+             'test_size': len(test_data),
+             'train_percent': len(train_data) / n * 100,
+             'val_percent': len(val_data) / n * 100,
+             'test_percent': len(test_data) / n * 100,
+             'total_samples': n,
+             'features_count': len(full_data.columns),
+             'split_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+         }
+ 
+         # Add temporal period information if available
+         if isinstance(full_data.index, pd.DatetimeIndex):
+             self.split_info.update({
+                 'train_period': {
+                     'start': train_data.index.min().strftime('%Y-%m-%d'),
+                     'end': train_data.index.max().strftime('%Y-%m-%d')
+                 },
+                 'val_period': {
+                     'start': val_data.index.min().strftime('%Y-%m-%d'),
+                     'end': val_data.index.max().strftime('%Y-%m-%d')
+                 },
+                 'test_period': {
+                     'start': test_data.index.min().strftime('%Y-%m-%d'),
+                     'end': test_data.index.max().strftime('%Y-%m-%d')
+                 }
+             })
+ 
+         # Save split indices
+         self.split_indices = {
+             'train': train_data.index.tolist(),
+             'val': val_data.index.tolist(),
+             'test': test_data.index.tolist()
+         }
+ 
+     def _log_split_summary(
+         self,
+         train_data: pd.DataFrame,
+         val_data: pd.DataFrame,
+         test_data: pd.DataFrame
+     ) -> None:
+         """Log splitting summary"""
+         logger.info("✓ Data split completed:")
+         logger.info(f"  Train: {len(train_data)} records ({self.split_info['train_percent']:.1f}%)")
+         logger.info(f"  Validation: {len(val_data)} records ({self.split_info['val_percent']:.1f}%)")
+         logger.info(f"  Test: {len(test_data)} records ({self.split_info['test_percent']:.1f}%)")
+ 
+         if 'train_period' in self.split_info:
+             logger.info("\nPeriods:")
+             logger.info(f"  Train: {self.split_info['train_period']['start']} - {self.split_info['train_period']['end']}")
+             logger.info(f"  Validation: {self.split_info['val_period']['start']} - {self.split_info['val_period']['end']}")
+             logger.info(f"  Test: {self.split_info['test_period']['start']} - {self.split_info['test_period']['end']}")
+ 
+         # Target variable statistics
+         target = self.config.target_column
+         if target in train_data.columns:
+             logger.info(f"\nTarget variable '{target}' statistics:")
+             logger.info(f"  Train: mean={train_data[target].mean():.2f}, std={train_data[target].std():.2f}")
+             logger.info(f"  Validation: mean={val_data[target].mean():.2f}, std={val_data[target].std():.2f}")
+             logger.info(f"  Test: mean={test_data[target].mean():.2f}, std={test_data[target].std():.2f}")
+ 
+     def _plot_data_split(
+         self,
+         full_data: pd.DataFrame,
+         train_data: pd.DataFrame,
+         val_data: pd.DataFrame,
+         test_data: pd.DataFrame
+     ) -> None:
+         """Visualise data splitting"""
+         fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+ 
+         target = self.config.target_column
+ 
+         # 1. Time series with set highlighting
+         if target in full_data.columns and isinstance(full_data.index, pd.DatetimeIndex):
+             axes[0, 0].plot(train_data.index, train_data[target],
+                             label='Train', color='blue', alpha=0.7, linewidth=1)
+             axes[0, 0].plot(val_data.index, val_data[target],
+                             label='Validation', color='orange', alpha=0.7, linewidth=1)
+             axes[0, 0].plot(test_data.index, test_data[target],
+                             label='Test', color='red', alpha=0.7, linewidth=1)
+ 
+             axes[0, 0].set_title(f'Data Split: {target}')
+             axes[0, 0].set_xlabel('Date')
+             axes[0, 0].set_ylabel(target)
+             axes[0, 0].legend()
+             axes[0, 0].grid(True, alpha=0.3)
+ 
+         # 2. Yearly distribution
+         if isinstance(full_data.index, pd.DatetimeIndex):
+             full_data['year'] = full_data.index.year
+             train_data['year'] = train_data.index.year
+             val_data['year'] = val_data.index.year
+             test_data['year'] = test_data.index.year
+ 
+             years = sorted(full_data['year'].unique())
+             train_counts = [len(train_data[train_data['year'] == year]) for year in years]
+             val_counts = [len(val_data[val_data['year'] == year]) for year in years]
+             test_counts = [len(test_data[test_data['year'] == year]) for year in years]
+ 
+             x = np.arange(len(years))
+             width = 0.25
+ 
+             axes[0, 1].bar(x - width, train_counts, width, label='Train', color='blue', alpha=0.7)
+             axes[0, 1].bar(x, val_counts, width, label='Validation', color='orange', alpha=0.7)
+             axes[0, 1].bar(x + width, test_counts, width, label='Test', color='red', alpha=0.7)
+ 
+             axes[0, 1].set_title('Yearly Data Distribution')
+             axes[0, 1].set_xlabel('Year')
+             axes[0, 1].set_ylabel('Number of Records')
+             axes[0, 1].set_xticks(x)
+             axes[0, 1].set_xticklabels(years, rotation=45)
+             axes[0, 1].legend()
+             axes[0, 1].grid(True, alpha=0.3)
+ 
+             # Remove the helper columns added above
+             for df in [full_data, train_data, val_data, test_data]:
+                 if 'year' in df.columns:
+                     df.drop('year', axis=1, inplace=True)
+ 
+         # 3. Target variable distribution
+         if target in full_data.columns:
+             axes[1, 0].hist(train_data[target].dropna(), bins=30, alpha=0.5, label='Train', density=True)
+             axes[1, 0].hist(val_data[target].dropna(), bins=30, alpha=0.5, label='Validation', density=True)
+             axes[1, 0].hist(test_data[target].dropna(), bins=30, alpha=0.5, label='Test', density=True)
+ 
+             axes[1, 0].set_title(f'{target} Distribution Across Sets')
+             axes[1, 0].set_xlabel(target)
+             axes[1, 0].set_ylabel('Density')
+             axes[1, 0].legend()
+             axes[1, 0].grid(True, alpha=0.3)
+ 
+         # 4. Set statistics
+         if target in full_data.columns:
+             stats_data = []
+             for name, df in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
+                 if target in df.columns:
+                     stats_data.append({
+                         'Dataset': name,
+                         'Mean': df[target].mean(),
+                         'Std': df[target].std(),
+                         'Min': df[target].min(),
+                         'Max': df[target].max()
+                     })
+ 
+             if stats_data:
+                 stats_df = pd.DataFrame(stats_data)
+                 stats_table = axes[1, 1].table(
+                     cellText=stats_df.round(2).values,
+                     colLabels=stats_df.columns,
+                     cellLoc='center',
+                     loc='center'
+                 )
+                 stats_table.auto_set_font_size(False)
+                 stats_table.set_fontsize(9)
+                 stats_table.scale(1, 1.5)
+                 axes[1, 1].axis('off')
+                 axes[1, 1].set_title('Set Statistics')
+ 
+         plt.suptitle(f'Data Splitting: {self.split_info["method"]} method', fontsize=14)
+         plt.tight_layout()
+         plt.savefig(
+             f'{self.config.results_dir}/plots/data_split.png',
+             dpi=300,
+             bbox_inches='tight'
+         )
+         plt.show()
+ 
+     def get_report(self) -> Dict:
+         """Get data splitting report"""
+         return self.split_info
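A minimal sketch of the new DataSplitter on a synthetic daily series (hypothetical values; it assumes Config accepts target_column, test_size, validation_size, split_method and save_plots, which split() and the logging/plotting helpers read above):

    import numpy as np
    import pandas as pd
    from config.config import Config
    from splitting.data_splitter import DataSplitter

    idx = pd.date_range('1970-01-01', periods=1000, freq='D')
    data = pd.DataFrame({'raskhodvoda': np.random.rand(1000) * 100}, index=idx)

    config = Config(target_column='raskhodvoda', test_size=0.2,
                    validation_size=0.1, split_method='time', save_plots=False)
    splitter = DataSplitter(config)
    train, val, test = splitter.split(data)  # chronological 70/10/20 split
    report = splitter.get_report()
    print(report['train_size'], report['val_size'], report['test_size'])  # 700 100 200

The 'time' method keeps temporal order, which avoids look-ahead leakage; 'random' shuffles and is only appropriate when rows are independent.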
stationarity/__init__.py ADDED
File without changes
stationarity/stationarity_checker.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================
2
+ # CLASS 6: STATIONARITY ANALYSIS
3
+ # ============================================
4
+ from typing import Dict, Optional
5
+ from venv import logger
6
+ from config.config import Config
7
+ import pandas as pd
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from statsmodels.tsa.stattools import adfuller, kpss, acf, pacf
11
+ from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
12
+
13
+ class StationarityChecker:
14
+ """Class for checking time series stationarity"""
15
+
16
+ def __init__(self, config: Config):
17
+ """
18
+ Initialise stationarity checker
19
+
20
+ Parameters:
21
+ -----------
22
+ config : Config
23
+ Experiment configuration
24
+ """
25
+ self.config = config
26
+ self.test_results = {}
27
+ self.transformed_series = {}
28
+ self.best_transformation = {}
29
+
30
+ def check(
31
+ self,
32
+ data: pd.DataFrame,
33
+ target_col: Optional[str] = None,
34
+ make_stationary: bool = True,
35
+ try_transformations: bool = True
36
+ ) -> Dict:
37
+ """
38
+ Check time series stationarity
39
+
40
+ Parameters:
41
+ -----------
42
+ data : pd.DataFrame
43
+ Input data
44
+ target_col : str, optional
45
+ Target variable. If None, uses configuration default.
46
+ make_stationary : bool
47
+ Transform series to stationary form
48
+ try_transformations : bool
49
+ Try various transformations to achieve stationarity
50
+
51
+ Returns:
52
+ --------
53
+ Dict
54
+ Stationarity test results
55
+ """
56
+ logger.info("\n" + "="*80)
57
+ logger.info("STATIONARITY ANALYSIS")
58
+ logger.info("="*80)
59
+
60
+ target_col = target_col or self.config.target_column
61
+
62
+ if target_col not in data.columns:
63
+ logger.error(f"Target variable '{target_col}' not found")
64
+ return {}
65
+
66
+ series = data[target_col].dropna()
67
+
68
+ if len(series) < 10:
69
+ logger.warning("Insufficient data for stationarity analysis")
70
+ return {}
71
+
72
+ # Perform analysis
73
+ results = self._perform_stationarity_tests(series, target_col)
74
+
75
+ # Save results
76
+ self.test_results[target_col] = results
77
+
78
+ # Visualisation
79
+ if self.config.save_plots:
80
+ self._plot_stationarity_analysis(data, target_col, results)
81
+
82
+ # Log results
83
+ self._log_test_results(target_col, results)
84
+
85
+ # Transform to stationary form
86
+ if make_stationary and not results['overall']['is_stationary']:
87
+ if try_transformations:
88
+ transformed_data = self._make_stationary(data, target_col, results)
89
+ if transformed_data is not None:
90
+ data = transformed_data
91
+
92
+ return results
93
+
94
+
95
+ def _perform_stationarity_tests(
96
+ self,
97
+ series: pd.Series,
98
+ target_col: str
99
+ ) -> Dict:
100
+ """Perform various stationarity tests"""
101
+ results = {
102
+ 'adf': self._adf_test(series),
103
+ 'kpss': self._kpss_test(series),
104
+ 'pp': self._pp_test(series),
105
+ 'hurst': self._hurst_exponent(series),
106
+ 'variance_ratio': self._variance_ratio_test(series),
107
+ 'overall': {}
108
+ }
109
+
110
+ # Determine overall stationarity
111
+ adf_stationary = results['adf'].get('is_stationary', False)
112
+ kpss_stationary = results['kpss'].get('is_stationary', False)
113
+ pp_stationary = results['pp'].get('is_stationary', False)
114
+
115
+ # Stationarity determination logic
116
+ if adf_stationary and kpss_stationary:
117
+ overall_stationary = True
118
+ confidence = 'high'
119
+ elif adf_stationary and not kpss_stationary:
120
+ overall_stationary = True # ADF more reliable for detecting stationarity
121
+ confidence = 'medium'
122
+ elif not adf_stationary and kpss_stationary:
123
+ overall_stationary = False # KPSS indicates non-stationarity
124
+ confidence = 'medium'
125
+ else:
126
+ overall_stationary = False
127
+ confidence = 'high'
128
+
129
+ results['overall'] = {
130
+ 'is_stationary': overall_stationary,
131
+ 'confidence': confidence,
132
+ 'recommendation': self._get_stationarity_recommendation(results)
133
+ }
134
+
135
+ return results
136
+
137
+ def _adf_test(self, series: pd.Series) -> Dict:
138
+ """Augmented Dickey-Fuller (ADF) test"""
139
+ try:
140
+ adf_result = adfuller(series, autolag='AIC')
141
+
142
+ return {
143
+ 'statistic': float(adf_result[0]),
144
+ 'pvalue': float(adf_result[1]),
145
+ 'critical_values': {k: float(v) for k, v in adf_result[4].items()},
146
+ 'is_stationary': adf_result[1] < 0.05,
147
+ 'used_lag': int(adf_result[2]),
148
+ 'nobs': int(adf_result[3])
149
+ }
150
+ except Exception as e:
151
+ logger.warning(f"ADF test failed: {e}")
152
+ return {
153
+ 'statistic': np.nan,
154
+ 'pvalue': np.nan,
155
+ 'critical_values': {},
156
+ 'is_stationary': False,
157
+ 'error': str(e)
158
+ }
159
+
160
+ def _kpss_test(self, series: pd.Series) -> Dict:
161
+ """KPSS test"""
162
+ try:
163
+ kpss_result = kpss(series, regression='c', nlags='auto')
164
+
165
+ return {
166
+ 'statistic': float(kpss_result[0]),
167
+ 'pvalue': float(kpss_result[1]),
168
+ 'critical_values': {k: float(v) for k, v in kpss_result[3].items()},
169
+ 'is_stationary': kpss_result[1] > 0.05, # KPSS: p > 0.05 indicates stationarity
170
+ 'used_lag': int(kpss_result[2])
171
+ }
172
+ except Exception as e:
173
+ logger.warning(f"KPSS test failed: {e}")
174
+ return {
175
+ 'statistic': np.nan,
176
+ 'pvalue': np.nan,
177
+ 'critical_values': {},
178
+ 'is_stationary': False,
179
+ 'error': str(e)
180
+ }
181
+
182
+ def _pp_test(self, series: pd.Series) -> Dict:
183
+ """Phillips-Perron test"""
184
+ try:
185
+ # Simplified PP test version
186
+ from statsmodels.tsa.stattools import PhillipsPerron
187
+
188
+ pp_result = PhillipsPerron(series)
189
+
190
+ return {
191
+ 'statistic': float(pp_result.stat),
192
+ 'pvalue': float(pp_result.pvalue),
193
+ 'critical_values': pp_result.critical_values,
194
+ 'is_stationary': pp_result.pvalue < 0.05
195
+ }
196
+ except:
197
+ # If statsmodels with PP test not available
198
+ return {
199
+ 'statistic': np.nan,
200
+ 'pvalue': np.nan,
201
+ 'critical_values': {},
202
+ 'is_stationary': False,
203
+ 'note': 'Phillips-Perron test not available'
204
+ }
205
+
206
+ def _hurst_exponent(self, series: pd.Series) -> Dict:
207
+ """Calculate Hurst exponent"""
208
+ try:
209
+ # Simplified Hurst exponent calculation
210
+ lags = range(2, min(100, len(series)//4))
211
+ tau = []
212
+
213
+ for lag in lags:
214
+ # Split series into subsequences of length lag
215
+ n = len(series) // lag
216
+ if n < 2:
217
+ continue
218
+
219
+ subseries = [series[i*lag:(i+1)*lag] for i in range(n)]
220
+ # Calculate R/S for each subsequence
221
+ rs_values = []
222
+ for sub in subseries:
223
+ if len(sub) > 1:
224
+ mean = np.mean(sub)
225
+ deviations = sub - mean
226
+ z = np.cumsum(deviations)
227
+ r = np.max(z) - np.min(z)
228
+ s = np.std(sub)
229
+ if s > 0:
230
+ rs_values.append(r / s)
231
+
232
+ if rs_values:
233
+ tau.append(np.mean(rs_values))
234
+
235
+ if len(tau) > 2:
236
+ # Linear regression in log coordinates
237
+ x = np.log(lags[:len(tau)])
238
+ y = np.log(tau)
239
+
240
+ if len(x) > 1 and len(y) > 1:
241
+ slope = np.polyfit(x, y, 1)[0]
242
+
243
+ # Hurst exponent interpretation
244
+ if slope > 0.5:
245
+ trend_type = 'persistent'
246
+ elif slope < 0.5:
247
+ trend_type = 'anti-persistent'
248
+ else:
249
+ trend_type = 'random'
250
+
251
+ return {
252
+ 'exponent': float(slope),
253
+ 'trend_type': trend_type,
254
+ 'interpretation': self._interpret_hurst(slope)
255
+ }
256
+
257
+ return {
258
+ 'exponent': np.nan,
259
+ 'trend_type': 'unknown',
260
+ 'interpretation': 'Insufficient data'
261
+ }
262
+
263
+ except Exception as e:
264
+ logger.debug(f"Hurst exponent not calculated: {e}")
265
+ return {
266
+ 'exponent': np.nan,
267
+ 'trend_type': 'unknown',
268
+ 'error': str(e)
269
+ }
270
+
271
+ def _interpret_hurst(self, hurst_exponent: float) -> str:
272
+ """Interpret Hurst exponent"""
273
+ if hurst_exponent > 0.75:
274
+ return "Strong persistence (long-term memory)"
275
+ elif hurst_exponent > 0.6:
276
+ return "Moderate persistence"
277
+ elif hurst_exponent > 0.4:
278
+ return "Weak persistence / random walk"
279
+ elif hurst_exponent > 0.25:
280
+ return "Weak anti-persistence"
281
+ else:
282
+ return "Strong anti-persistence (frequent trend reversal)"
283
+
284
+ def _variance_ratio_test(self, series: pd.Series) -> Dict:
285
+ """Variance Ratio test for random walk"""
286
+ try:
287
+ # Simplified variance ratio test
288
+ if len(series) < 20:
289
+ return {'ratio': np.nan, 'is_random_walk': False}
290
+
291
+ # Calculate differences
292
+ diff1 = series.diff(1).dropna()
293
+ diff2 = series.diff(2).dropna()[1:] # Shift to align indices
294
+
295
+ if len(diff1) < 5 or len(diff2) < 5:
296
+ return {'ratio': np.nan, 'is_random_walk': False}
297
+
298
+ var1 = np.var(diff1)
299
+ var2 = np.var(diff2)
300
+
301
+ if var1 > 0:
302
+ ratio = var2 / (2 * var1)
303
+
304
+ # For random walk ratio ≈ 1
305
+ is_random_walk = 0.8 < ratio < 1.2
306
+
307
+ return {
308
+ 'ratio': float(ratio),
309
+ 'is_random_walk': bool(is_random_walk),
310
+ 'var_diff1': float(var1),
311
+ 'var_diff2': float(var2)
312
+ }
313
+ else:
314
+ return {'ratio': np.nan, 'is_random_walk': False}
315
+
316
+ except Exception as e:
317
+ logger.debug(f"Variance ratio test failed: {e}")
318
+ return {'ratio': np.nan, 'is_random_walk': False, 'error': str(e)}
319
+
320
+ def _get_stationarity_recommendation(self, results: Dict) -> str:
321
+ """Get stationarity recommendations"""
322
+ # Check for keys before access
323
+ if 'overall' not in results or 'is_stationary' not in results['overall']:
324
+ return "Could not determine stationarity. Check data and test settings."
325
+
326
+ if results['overall']['is_stationary']:
327
+ return "Series is stationary, suitable for modelling"
328
+ else:
329
+ recommendations = []
330
+
331
+ # Check Hurst test results
332
+ if 'hurst' in results and 'exponent' in results['hurst']:
333
+ hurst_exponent = results['hurst']['exponent']
334
+ if not np.isnan(hurst_exponent) and hurst_exponent > 0.6:
335
+ recommendations.append("Apply differencing to remove trend")
336
+
337
+ # Check ADF test
338
+ if 'adf' in results and 'pvalue' in results['adf']:
339
+ adf_pvalue = results['adf']['pvalue']
340
+ if not np.isnan(adf_pvalue) and adf_pvalue > 0.1:
341
+ recommendations.append("Consider seasonal differencing due to non-stationarity")
342
+
343
+ if len(recommendations) == 0:
344
+ recommendations.append("Try logarithmic transformation and differencing")
345
+
346
+ return "; ".join(recommendations)
347
+
348
+ def _plot_stationarity_analysis(
349
+ self,
350
+ data: pd.DataFrame,
351
+ target_col: str,
352
+ results: Dict
353
+ ) -> None:
354
+ """Visualise stationarity analysis"""
355
+ series = data[target_col]
356
+
357
+ fig, axes = plt.subplots(2, 3, figsize=(16, 10))
358
+
359
+ # 1. Original series
360
+ axes[0, 0].plot(series.index, series, linewidth=1)
361
+ axes[0, 0].set_title(f'Original Time Series: {target_col}')
362
+ axes[0, 0].set_xlabel('Date')
363
+ axes[0, 0].set_ylabel(target_col)
364
+ axes[0, 0].grid(True, alpha=0.3)
365
+
366
+ # 2. Rolling statistics
367
+ rolling_mean = series.rolling(window=365, center=True, min_periods=1).mean()
368
+ rolling_std = series.rolling(window=365, center=True, min_periods=1).std()
369
+
370
+ axes[0, 1].plot(series.index, series, label='Original series', alpha=0.7, linewidth=0.5)
371
+ axes[0, 1].plot(rolling_mean.index, rolling_mean, label='Rolling mean (365)', color='red', linewidth=2)
372
+ axes[0, 1].plot(rolling_std.index, rolling_std, label='Rolling STD (365)', color='green', linewidth=2)
373
+ axes[0, 1].set_title(f'Rolling Statistics: {target_col}')
374
+ axes[0, 1].set_xlabel('Date')
375
+ axes[0, 1].set_ylabel(target_col)
376
+ axes[0, 1].legend(fontsize=8)
377
+ axes[0, 1].grid(True, alpha=0.3)
378
+
379
+ # 3. ACF
380
+ plot_acf(series.dropna(), lags=50, ax=axes[0, 2], alpha=0.05)
381
+ axes[0, 2].set_title(f'Autocorrelation Function (ACF): {target_col}')
382
+ axes[0, 2].set_xlabel('Lag')
383
+ axes[0, 2].set_ylabel('Autocorrelation')
384
+ axes[0, 2].grid(True, alpha=0.3)
385
+
386
+ # 4. PACF
387
+ plot_pacf(series.dropna(), lags=50, ax=axes[1, 0], alpha=0.05)
388
+ axes[1, 0].set_title(f'Partial Autocorrelation Function (PACF): {target_col}')
389
+ axes[1, 0].set_xlabel('Lag')
390
+ axes[1, 0].set_ylabel('Partial Autocorrelation')
391
+ axes[1, 0].grid(True, alpha=0.3)
392
+
393
+ # 5. Histogram and Q-Q plot
394
+ axes[1, 1].hist(series.dropna(), bins=30, edgecolor='black', alpha=0.7, density=True)
395
+ axes[1, 1].set_title(f'Distribution: {target_col}')
396
+ axes[1, 1].set_xlabel('Value')
397
+ axes[1, 1].set_ylabel('Density')
398
+ axes[1, 1].grid(True, alpha=0.3)
399
+
400
+ # 6. Series differences
401
+ diff1 = series.diff(1).dropna()
402
+ axes[1, 2].plot(diff1.index, diff1, linewidth=0.5)
403
+ axes[1, 2].set_title(f'First Difference: {target_col}')
404
+ axes[1, 2].set_xlabel('Date')
405
+ axes[1, 2].set_ylabel(f'Δ{target_col}')
406
+ axes[1, 2].grid(True, alpha=0.3)
407
+
408
+ plt.suptitle(
409
+ f'Stationarity Analysis: {target_col}\n'
410
+ f'Stationary: {"✓ Yes" if results["overall"]["is_stationary"] else "✗ No"} '
411
+ f'(confidence: {results["overall"]["confidence"]})',
412
+ fontsize=14
413
+ )
414
+
415
+ plt.tight_layout()
416
+ plt.savefig(
417
+ f'{self.config.results_dir}/plots/stationarity_{target_col}.png',
418
+ dpi=300,
419
+ bbox_inches='tight'
420
+ )
421
+ plt.show()
422
+
423
+ def _log_test_results(self, target_col: str, results: Dict) -> None:
424
+ """Log test results"""
425
+ logger.info("\nSTATIONARITY TEST RESULTS:")
426
+ logger.info("-" * 50)
427
+
428
+ # ADF test
429
+ adf = results['adf']
430
+ logger.info(f"Augmented Dickey-Fuller (ADF) test:")
431
+ logger.info(f" Statistic: {adf['statistic']:.4f}")
432
+ logger.info(f" p-value: {adf['pvalue']:.4f}")
433
+ logger.info(f" Stationary: {'✓ Yes' if adf['is_stationary'] else '✗ No'}")
434
+
435
+ # KPSS test
436
+ kpss_test = results['kpss']
437
+ if 'statistic' in kpss_test and not np.isnan(kpss_test['statistic']):
438
+ logger.info(f"\nKPSS test:")
439
+ logger.info(f" Statistic: {kpss_test['statistic']:.4f}")
440
+ logger.info(f" p-value: {kpss_test['pvalue']:.4f}")
441
+ logger.info(f" Stationary: {'✓ Yes' if kpss_test['is_stationary'] else '✗ No'}")
442
+
443
+ # Hurst exponent
444
+ hurst = results['hurst']
445
+ if 'exponent' in hurst and not np.isnan(hurst['exponent']):
446
+ logger.info(f"\nHurst exponent:")
447
+ logger.info(f" Value: {hurst['exponent']:.3f}")
448
+ logger.info(f" Trend type: {hurst['trend_type']}")
449
+ logger.info(f" Interpretation: {hurst.get('interpretation', '')}")
450
+
451
+ # Overall interpretation
452
+ logger.info(f"\nOVERALL CONCLUSION:")
453
+ logger.info("-" * 30)
454
+ logger.info(f"Stationary: {'✓ Yes' if results['overall']['is_stationary'] else '✗ No'}")
455
+ logger.info(f"Confidence: {results['overall']['confidence']}")
456
+ logger.info(f"Recommendation: {results['overall']['recommendation']}")
457
+
458
+ def _make_stationary(
459
+ self,
460
+ data: pd.DataFrame,
461
+ target_col: str,
462
+ results: Dict
463
+ ) -> Optional[pd.DataFrame]:
464
+ """
465
+ Transform series to stationary form
466
+
467
+ Parameters:
468
+ -----------
469
+ data : pd.DataFrame
470
+ Input data
471
+ target_col : str
472
+ Target variable
473
+ results : Dict
474
+ Stationarity test results
475
+
476
+ Returns:
477
+ --------
478
+ Optional[pd.DataFrame]
479
+ Data with stationary series or None if transformation failed
480
+ """
481
+ logger.info("\nTRANSFORMING TO STATIONARY FORM:")
482
+ logger.info("-" * 40)
483
+
484
+ data_processed = data.copy()
485
+ series = data_processed[target_col]
486
+
487
+ # Stationarisation methods in order of preference
488
+ methods = [
489
+ ('diff', 'first-order differencing'),
490
+ ('seasonal_diff', f'seasonal differencing (period={self.config.seasonal_period})'),
491
+ ('log_diff', 'logarithmic differencing'),
492
+ ('boxcox_diff', 'Box-Cox + differencing'),
493
+ ('detrend', 'detrending'),
494
+ ('combination', 'combined method')
495
+ ]
496
+
497
+ best_method = None
498
+ best_series = None
499
+ best_pvalue = 1.0
500
+ best_stationary = False
501
+
502
+ for method, method_name in methods:
503
+ try:
504
+ if method == 'diff':
505
+ # Simple differencing
506
+ transformed = series.diff(1).dropna()
507
+ test_series = transformed
508
+
509
+ elif method == 'seasonal_diff':
510
+ # Seasonal differencing
511
+ transformed = series.diff(self.config.seasonal_period).dropna()
512
+ test_series = transformed
513
+
514
+ elif method == 'log_diff':
515
+ # Logarithmic differencing
516
+ if (series > 0).all():
517
+ log_series = np.log(series)
518
+ transformed = log_series.diff(1).dropna()
519
+ test_series = transformed
520
+ else:
521
+ # Shift for negative values
522
+ shift = abs(series.min()) + 1 if series.min() <= 0 else 0
523
+ log_series = np.log(series + shift)
524
+ transformed = log_series.diff(1).dropna()
525
+ test_series = transformed
526
+
527
+ elif method == 'boxcox_diff':
528
+ # Box-Cox transformation + differencing
529
+ try:
530
+ from scipy.stats import boxcox
531
+ # Add constant for positive values
532
+ shift = abs(series.min()) + 1 if series.min() <= 0 else 0
533
+ boxcox_series, _ = boxcox(series + shift)
534
+ transformed = pd.Series(boxcox_series, index=series.index).diff(1).dropna()
535
+ test_series = transformed
536
+ except:
537
+ continue
538
+
539
+ elif method == 'detrend':
540
+ # Linear detrending
541
+ x = np.arange(len(series))
542
+ y = series.values
543
+ coeffs = np.polyfit(x, y, 1)
544
+ trend = np.polyval(coeffs, x)
545
+ transformed = pd.Series(y - trend, index=series.index)
546
+ test_series = transformed
547
+
548
+ elif method == 'combination':
549
+ # Combined method: log + differencing + detrending
550
+ if (series > 0).all():
551
+ log_series = np.log(series)
552
+ diff_series = log_series.diff(1)
553
+
554
+ # Detrending residuals
555
+ x = np.arange(len(diff_series))
556
+ y = diff_series.values
557
+ valid_mask = ~np.isnan(y)
558
+
559
+ if valid_mask.sum() > 2:
560
+ coeffs = np.polyfit(x[valid_mask], y[valid_mask], 1)
561
+ trend = np.polyval(coeffs, x)
562
+ transformed = pd.Series(y - trend, index=series.index)
563
+ test_series = transformed.dropna()
564
+ else:
565
+ test_series = diff_series.dropna()
566
+ else:
567
+ continue
568
+
569
+ # Check stationarity after transformation
570
+ if len(test_series) > 10:
571
+ adf_result = adfuller(test_series.dropna())
572
+ is_stationary = adf_result[1] < 0.05
573
+ pvalue = adf_result[1]
574
+
575
+ logger.info(f" Method: {method_name}")
576
+ logger.info(f" ADF p-value: {pvalue:.4f}")
577
+ logger.info(f" Stationary: {'✓ Yes' if is_stationary else '✗ No'}")
578
+
579
+ # Save best method
580
+ if is_stationary and pvalue < best_pvalue:
581
+ best_pvalue = pvalue
582
+ best_method = method
583
+ best_series = transformed
584
+ best_stationary = True
585
+
586
+ if pvalue < 0.01:  # strong rejection of the unit root; stop searching
587
+ break
588
+
589
+ except Exception as e:
590
+ logger.debug(f" Method {method} failed: {e}")
591
+ continue
592
+
593
+ # Save results
594
+ if best_series is not None:
595
+ new_col_name = f'{target_col}_stationary_{best_method}'
596
+
597
+ # Align indices
598
+ aligned_series = pd.Series(
599
+ best_series.values,
600
+ index=data_processed.index[-len(best_series):]
601
+ )
602
+
603
+ data_processed[new_col_name] = aligned_series
604
+
605
+ self.transformed_series[target_col] = {
606
+ 'method': best_method,
607
+ 'new_column': new_col_name,
608
+ 'pvalue': float(best_pvalue),
609
+ 'is_stationary': best_stationary,
610
+ 'original_shape': len(series),
611
+ 'transformed_shape': len(best_series)
612
+ }
613
+
614
+ self.best_transformation[target_col] = best_method
615
+
616
+ logger.info(f"\n✓ Selected method: {best_method}")
617
+ logger.info(f" Saved as '{new_col_name}'")
618
+ logger.info(f" p-value: {best_pvalue:.4f}")
619
+
620
+ return data_processed
621
+ else:
622
+ logger.warning("✗ Could not find suitable transformation for stationarisation")
623
+ return None
624
+
625
+ def get_report(self) -> Dict:
626
+ """Get stationarity report"""
627
+ return {
628
+ 'test_results': self.test_results,
629
+ 'transformed_series': self.transformed_series,
630
+ 'best_transformations': self.best_transformation
631
+ }
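
A minimal usage sketch for the stationarity helper above. The enclosing class name, its constructor, and the module path are not shown in this diff, so `StationarityProcessor` and the import below are assumptions; only `_make_stationary` and `get_report` appear here.

```python
# Hypothetical usage of the stationarity helper defined above.
# StationarityProcessor, its module path, and the Config fields
# (seasonal_period, target_column) are assumptions, not confirmed by this diff.
import numpy as np
import pandas as pd

from config.config import Config
from preprocessing.stationarity import StationarityProcessor  # assumed path

config = Config()
processor = StationarityProcessor(config)

# A synthetic trending series: clearly non-stationary until differenced
idx = pd.date_range("2020-01-01", periods=500, freq="D")
rng = np.random.default_rng(42)
data = pd.DataFrame(
    {"demand": np.linspace(0.0, 50.0, 500) + rng.normal(0.0, 1.0, 500)},
    index=idx,
)

transformed = processor._make_stationary(data, "demand", results={})
if transformed is not None:
    # Expect first-order differencing to win for a linear trend
    print(processor.get_report()["best_transformations"])  # e.g. {'demand': 'diff'}
```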
temp_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
validation/__init__.py ADDED
File without changes
validation/data_validator.py ADDED
@@ -0,0 +1,655 @@
1
+ # ============================================
2
+ # CLASS 12: DATA VALIDATION
3
+ # ============================================
4
+ from datetime import datetime
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Dict, List
8
+ import logging
+ logger = logging.getLogger(__name__)
9
+
10
+ from config.config import Config
11
+ import pandas as pd
12
+ import numpy as np
13
+
14
+ class DataValidator:
15
+ """Class for data quality validation"""
16
+
17
+ def __init__(self, config: Config):
18
+ """
19
+ Initialise data validator
20
+
21
+ Parameters:
22
+ -----------
23
+ config : Config
24
+ Experiment configuration
25
+ """
26
+ self.config = config
27
+ self.validation_results = {}
28
+ self.quality_metrics = {}
29
+ self.issues_found = {}
30
+
31
+ def validate(
32
+ self,
33
+ data: pd.DataFrame,
34
+ stage: str = 'final',
35
+ rules: Dict = None,
36
+ detailed: bool = True
37
+ ) -> Dict:
38
+ """
39
+ Validate data quality
40
+
41
+ Parameters:
42
+ -----------
43
+ data : pd.DataFrame
44
+ Input data
45
+ stage : str
46
+ Validation stage: 'raw', 'processed', 'final'
47
+ rules : Dict, optional
48
+ Validation rules. If None, uses configuration defaults.
49
+ detailed : bool
50
+ Whether to perform detailed validation
51
+
52
+ Returns:
53
+ --------
54
+ Dict
55
+ Validation results
56
+ """
57
+ logger.info("\n" + "="*80)
58
+ logger.info(f"DATA VALIDATION ({stage.upper()})")
59
+ logger.info("="*80)
60
+
61
+ rules = rules or self.config.validation_rules
62
+
63
+ validation_results = {
64
+ 'stage': stage,
65
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
66
+ 'data_shape': list(data.shape),
67
+ 'basic_checks': {},
68
+ 'quality_metrics': {},
69
+ 'issues': {},
70
+ 'recommendations': [],
71
+ 'overall_score': 0,
72
+ 'status': 'PASS'
73
+ }
74
+
75
+ # Basic checks
76
+ validation_results['basic_checks'] = self._basic_checks(data, rules)
77
+
78
+ # Quality checks
79
+ validation_results['quality_metrics'] = self._quality_metrics(data, rules)
80
+
81
+ # Problem detection
82
+ if detailed:
83
+ validation_results['issues'] = self._find_issues(data, rules)
84
+
85
+ # Recommendation generation
86
+ validation_results['recommendations'] = self._generate_recommendations(
87
+ validation_results['basic_checks'],
88
+ validation_results['quality_metrics'],
89
+ validation_results['issues']
90
+ )
91
+
92
+ # Overall score calculation
93
+ validation_results['overall_score'] = self._calculate_overall_score(validation_results)
94
+
95
+ # Status determination
96
+ if validation_results['overall_score'] >= 80:
97
+ validation_results['status'] = 'PASS'
98
+ elif validation_results['overall_score'] >= 60:
99
+ validation_results['status'] = 'WARNING'
100
+ else:
101
+ validation_results['status'] = 'FAIL'
102
+
103
+ # Save results
104
+ self.validation_results[stage] = validation_results
105
+ self.quality_metrics[stage] = validation_results['quality_metrics']
106
+
107
+ # Log results
108
+ self._log_validation_results(validation_results)
109
+
110
+ return validation_results
111
+
112
+ def _basic_checks(self, data: pd.DataFrame, rules: Dict) -> Dict:
113
+ """Basic data checks"""
114
+ checks = {}
115
+
116
+ # 1. Data size check
117
+ checks['min_rows'] = {
118
+ 'value': len(data),
119
+ 'threshold': rules.get('min_rows', 100),
120
+ 'passed': len(data) >= rules.get('min_rows', 100)
121
+ }
122
+
123
+ # 2. Target variable presence check
124
+ target = self.config.target_column
125
+ checks['has_target'] = {
126
+ 'value': target in data.columns,
127
+ 'passed': target in data.columns
128
+ }
129
+
130
+ # 3. Missing values check
131
+ missing_percentage = (data.isnull().sum().sum() / data.size) * 100
132
+ checks['missing_percentage'] = {
133
+ 'value': missing_percentage,
134
+ 'threshold': rules.get('max_missing_percentage', 30),
135
+ 'passed': missing_percentage <= rules.get('max_missing_percentage', 30)
136
+ }
137
+
138
+ # 4. Duplicates check
139
+ duplicate_count = data.duplicated().sum()
140
+ duplicate_percentage = (duplicate_count / len(data)) * 100
141
+ checks['duplicates'] = {
142
+ 'value': duplicate_percentage,
143
+ 'threshold': 5, # Maximum 5% duplicates
144
+ 'passed': duplicate_percentage <= 5
145
+ }
146
+
147
+ # 5. Data types check
148
+ numeric_count = len(data.select_dtypes(include=[np.number]).columns)
149
+ checks['numeric_features'] = {
150
+ 'value': numeric_count,
151
+ 'passed': numeric_count >= 1 # At least one numeric feature required
152
+ }
153
+
154
+ return checks
155
+
156
+ def _quality_metrics(self, data: pd.DataFrame, rules: Dict) -> Dict:
157
+ """Data quality metrics"""
158
+ metrics = {}
159
+
160
+ # 1. Numeric features statistics
161
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
162
+
163
+ if len(numeric_cols) > 0:
164
+ numeric_stats = {}
165
+ for col in numeric_cols:
166
+ col_data = data[col].dropna()
167
+ if len(col_data) > 0:
168
+ numeric_stats[col] = {
169
+ 'mean': float(col_data.mean()),
170
+ 'std': float(col_data.std()),
171
+ 'skewness': float(col_data.skew()),
172
+ 'kurtosis': float(col_data.kurtosis()),
173
+ 'zeros_percentage': float((col_data == 0).sum() / len(col_data) * 100),
174
+ 'unique_percentage': float(col_data.nunique() / len(col_data) * 100)
175
+ }
176
+
177
+ metrics['numeric_statistics'] = numeric_stats
178
+
179
+ # 2. Data stability (for time series)
180
+ if isinstance(data.index, pd.DatetimeIndex):
181
+ stability_metrics = self._calculate_temporal_stability(data)
182
+ metrics['temporal_stability'] = stability_metrics
183
+
184
+ # 3. Feature informativeness
185
+ if self.config.target_column in data.columns:
186
+ informativeness = self._calculate_feature_informativeness(data)
187
+ metrics['feature_informativeness'] = informativeness
188
+
189
+ # 4. Target variable quality
190
+ target = self.config.target_column
191
+ if target in data.columns:
192
+ target_data = data[target].dropna()
193
+ if len(target_data) > 0:
194
+ target_metrics = {
195
+ 'missing_percentage': float(data[target].isnull().sum() / len(data) * 100),  # measure on the raw column; target_data is already dropna'd
196
+ 'unique_values': int(target_data.nunique()),
197
+ 'is_constant': bool(target_data.nunique() <= 1),
198
+ 'has_outliers': self._check_target_outliers(target_data),
199
+ 'distribution_type': self._identify_distribution(target_data)
200
+ }
201
+ metrics['target_quality'] = target_metrics
202
+
203
+ # 5. Class balance (for classification) - not applicable here, but kept as placeholder
204
+ metrics['class_balance'] = {'note': 'Not applicable for regression'}
205
+
206
+ return metrics
207
+
208
+ def _calculate_temporal_stability(self, data: pd.DataFrame) -> Dict:
209
+ """Calculate time series stability metrics"""
210
+ stability = {}
211
+
212
+ if not isinstance(data.index, pd.DatetimeIndex):
213
+ return stability
214
+
215
+ # Split into periods (e.g., by years)
216
+ if 'year' not in data.columns:
217
+ data_copy = data.copy()
218
+ data_copy['year'] = data_copy.index.year
219
+ else:
220
+ data_copy = data
221
+
222
+ years = sorted(data_copy['year'].unique())
223
+
224
+ if len(years) > 1:
225
+ # Statistics by years for numeric columns
226
+ year_stats = {}
227
+ for col in data.select_dtypes(include=[np.number]).columns[:5]: # First 5 columns
228
+ yearly_means = data_copy.groupby('year')[col].mean()
229
+ yearly_stds = data_copy.groupby('year')[col].std()
230
+
231
+ # Coefficient of variation between years
232
+ if yearly_means.std() > 0:
233
+ cv_between_years = yearly_means.std() / yearly_means.mean()
234
+ else:
235
+ cv_between_years = 0
236
+
237
+ year_stats[col] = {
238
+ 'yearly_means': yearly_means.to_dict(),
239
+ 'yearly_stds': yearly_stds.to_dict(),
240
+ 'cv_between_years': float(cv_between_years),
241
+ 'mean_stability': float(1 - cv_between_years) # 1 - CV, closer to 1 means more stable
242
+ }
243
+
244
+ stability['yearly_statistics'] = year_stats
245
+
246
+ # Check for time gaps
247
+ time_diff = pd.Series(data.index).diff().dropna()
248
+ if len(time_diff) > 0:
249
+ max_gap = time_diff.max()
250
+ avg_gap = time_diff.mean()
251
+ gap_std = time_diff.std()
252
+
253
+ stability['time_gaps'] = {
254
+ 'max_gap_days': float(max_gap.days if hasattr(max_gap, 'days') else max_gap),
255
+ 'avg_gap_days': float(avg_gap.days if hasattr(avg_gap, 'days') else avg_gap),
256
+ 'gap_std': float(gap_std.days if hasattr(gap_std, 'days') else gap_std),
257
+ 'has_irregular_gaps': gap_std > avg_gap * 0.5 # If standard deviation > 50% of mean
258
+ }
259
+
260
+ # Seasonal stability
261
+ if len(data) > 365:
262
+ try:
263
+ # Analyse seasonal patterns
264
+ seasonal_stability = self._analyse_seasonal_stability(data)
265
+ stability['seasonal_stability'] = seasonal_stability
266
+ except Exception:
267
+ pass
268
+
269
+ return stability
270
+
271
+ def _analyse_seasonal_stability(self, data: pd.DataFrame) -> Dict:
272
+ """Analyse seasonal patterns stability"""
273
+ if not isinstance(data.index, pd.DatetimeIndex):
274
+ return {}
275
+
276
+ # For simplicity, analyse only target variable
277
+ target = self.config.target_column
278
+ if target not in data.columns:
279
+ return {}
280
+
281
+ series = data[target]
282
+
283
+ # Split by years and compare seasonal patterns
284
+ data_copy = data.copy()
285
+ data_copy['year'] = data_copy.index.year
286
+ data_copy['month'] = data_copy.index.month
287
+
288
+ if 'year' in data_copy.columns and 'month' in data_copy.columns:
289
+ monthly_means = data_copy.groupby(['year', 'month'])[target].mean().unstack()
290
+
291
+ if not monthly_means.empty:
292
+ # Correlation between years
293
+ yearly_corr = monthly_means.T.corr().mean().mean()  # correlate year profiles, not month columns
294
+
295
+ # Variation between years
296
+ monthly_cv = monthly_means.std() / monthly_means.mean()
297
+ avg_monthly_cv = monthly_cv.mean()
298
+
299
+ return {
300
+ 'yearly_correlation': float(yearly_corr),
301
+ 'average_monthly_cv': float(avg_monthly_cv),
302
+ 'seasonal_consistency': 'high' if yearly_corr > 0.8 and avg_monthly_cv < 0.3 else
303
+ 'medium' if yearly_corr > 0.6 else 'low'
304
+ }
305
+
306
+ return {}
307
+
308
+ def _calculate_feature_informativeness(self, data: pd.DataFrame) -> Dict:
309
+ """Calculate feature informativeness"""
310
+ informativeness = {}
311
+
312
+ target = self.config.target_column
313
+ if target not in data.columns:
314
+ return informativeness
315
+
316
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
317
+ numeric_cols = [col for col in numeric_cols if col != target]
318
+
319
+ for col in numeric_cols[:20]: # Limit number of features for analysis
320
+ try:
321
+ # Correlation with target variable
322
+ correlation = data[col].corr(data[target])
323
+
324
+ # Mutual information (approximated)
325
+ # For simplicity, use absolute correlation as informativeness measure
326
+ informativeness[col] = {
327
+ 'correlation_with_target': float(correlation),
328
+ 'abs_correlation': float(abs(correlation)),
329
+ 'informativeness': 'high' if abs(correlation) > 0.5 else
330
+ 'medium' if abs(correlation) > 0.3 else 'low'
331
+ }
332
+ except:
333
+ continue
334
+
335
+ return informativeness
336
+
337
+ def _check_target_outliers(self, target_series: pd.Series) -> Dict:
338
+ """Check target variable for outliers"""
339
+ if len(target_series) < 10:
340
+ return {'has_outliers': False, 'outlier_percentage': 0}
341
+
342
+ q1 = target_series.quantile(0.25)
343
+ q3 = target_series.quantile(0.75)
344
+ iqr = q3 - q1
345
+
346
+ if iqr > 0:
347
+ lower_bound = q1 - 1.5 * iqr
348
+ upper_bound = q3 + 1.5 * iqr
349
+
350
+ outliers = target_series[(target_series < lower_bound) | (target_series > upper_bound)]
351
+ outlier_percentage = len(outliers) / len(target_series) * 100
352
+
353
+ return {
354
+ 'has_outliers': len(outliers) > 0,
355
+ 'outlier_count': int(len(outliers)),
356
+ 'outlier_percentage': float(outlier_percentage),
357
+ 'outlier_bounds': {'lower': float(lower_bound), 'upper': float(upper_bound)}
358
+ }
359
+
360
+ return {'has_outliers': False, 'outlier_percentage': 0}
361
+
362
+ def _identify_distribution(self, series: pd.Series) -> str:
363
+ """Identify distribution type"""
364
+ if len(series) < 30:
365
+ return 'insufficient_data'
366
+
367
+ skewness = series.skew()
368
+ kurtosis = series.kurtosis()  # pandas returns excess kurtosis (0 for a normal distribution)
369
+
370
+ if abs(skewness) < 0.5 and abs(kurtosis) < 1:
371
+ return 'normal_like'
372
+ elif skewness > 1:
373
+ return 'right_skewed'
374
+ elif skewness < -1:
375
+ return 'left_skewed'
376
+ elif kurtosis > 3:
377
+ return 'heavy_tailed'
378
+ elif kurtosis < 2:
379
+ return 'light_tailed'
380
+ else:
381
+ return 'unknown'
382
+
383
+ def _find_issues(self, data: pd.DataFrame, rules: Dict) -> Dict:
384
+ """Find data problems"""
385
+ issues = {
386
+ 'critical': [],
387
+ 'warning': [],
388
+ 'info': []
389
+ }
390
+
391
+ # 1. Check missing values in important features
392
+ missing_info = data.isnull().sum()
393
+ high_missing_cols = missing_info[missing_info / len(data) * 100 > 20].index.tolist()
394
+
395
+ for col in high_missing_cols:
396
+ missing_pct = missing_info[col] / len(data) * 100
397
+ if missing_pct > 50:
398
+ issues['critical'].append(f"Column '{col}': {missing_pct:.1f}% missing values (critical)")
399
+ elif missing_pct > 20:
400
+ issues['warning'].append(f"Column '{col}': {missing_pct:.1f}% missing values")
401
+
402
+ # 2. Check constant features
403
+ for col in data.columns:
404
+ if data[col].nunique() <= 1:
405
+ issues['critical'].append(f"Column '{col}': constant value")
406
+
407
+ # 3. Check feature correlation with itself (lags)
408
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
409
+ for col in numeric_cols:
410
+ if '_lag_' in col or '_diff_' in col:
411
+ base_col = col.split('_lag_')[0] if '_lag_' in col else col.split('_diff_')[0]
412
+ if base_col in numeric_cols:
413
+ correlation = data[col].corr(data[base_col])
414
+ if pd.notna(correlation) and abs(correlation) > 0.95:
415
+ issues['info'].append(f"Column '{col}': high correlation with '{base_col}' ({correlation:.3f})")
416
+
417
+ # 4. Check time gaps
418
+ if isinstance(data.index, pd.DatetimeIndex):
419
+ time_diff = pd.Series(data.index).diff().dropna()
420
+ if len(time_diff) > 0:
421
+ max_gap = time_diff.max()
422
+ if hasattr(max_gap, 'days') and max_gap.days > 30:
423
+ issues['warning'].append(f"Detected time gap: {max_gap.days} days")
424
+
425
+ # 5. Check target variable
426
+ target = self.config.target_column
427
+ if target in data.columns:
428
+ target_data = data[target].dropna()
429
+ if len(target_data) > 0:
430
+ if target_data.nunique() <= 1:
431
+ issues['critical'].append(f"Target variable '{target}': constant value")
432
+
433
+ # Check for outliers
434
+ outlier_check = self._check_target_outliers(target_data)
435
+ if outlier_check.get('has_outliers', False) and outlier_check.get('outlier_percentage', 0) > 10:
436
+ issues['warning'].append(f"Target variable '{target}': {outlier_check['outlier_percentage']:.1f}% outliers")
437
+
438
+ # 6. Check multicollinearity (simplified)
439
+ if len(numeric_cols) > 5:
440
+ corr_matrix = data[numeric_cols].corr().abs()
441
+ high_corr_pairs = []
442
+
443
+ for i in range(len(corr_matrix.columns)):
444
+ for j in range(i+1, len(corr_matrix.columns)):
445
+ if corr_matrix.iloc[i, j] > 0.9:
446
+ col1 = corr_matrix.columns[i]
447
+ col2 = corr_matrix.columns[j]
448
+ high_corr_pairs.append((col1, col2, corr_matrix.iloc[i, j]))
449
+
450
+ if len(high_corr_pairs) > 5:
451
+ issues['warning'].append(f"Detected multicollinearity: {len(high_corr_pairs)} pairs with correlation > 0.9")
452
+
453
+ return issues
454
+
455
+ def _generate_recommendations(
456
+ self,
457
+ basic_checks: Dict,
458
+ quality_metrics: Dict,
459
+ issues: Dict
460
+ ) -> List[str]:
461
+ """Generate data improvement recommendations"""
462
+ recommendations = []
463
+
464
+ # Recommendations based on basic checks
465
+ for check_name, check_info in basic_checks.items():
466
+ if not check_info.get('passed', True):
467
+ if check_name == 'min_rows':
468
+ recommendations.append(f"Increase data volume: current row count ({check_info['value']}) below minimum threshold ({check_info['threshold']})")
469
+ elif check_name == 'has_target':
470
+ recommendations.append(f"Add target variable '{self.config.target_column}' to data")
471
+ elif check_name == 'missing_percentage':
472
+ recommendations.append(f"Handle missing values: {check_info['value']:.1f}% missing exceeds threshold {check_info['threshold']}%")
473
+ elif check_name == 'duplicates':
474
+ recommendations.append(f"Remove duplicates: {check_info['value']:.1f}% duplicate rows")
475
+
476
+ # Recommendations based on issues
477
+ if issues.get('critical'):
478
+ recommendations.append("Resolve critical issues before using data")
479
+
480
+ if issues.get('warning'):
481
+ recommendations.append("Consider addressing warnings to improve data quality")
482
+
483
+ # Recommendations based on quality metrics
484
+ target_metrics = quality_metrics.get('target_quality', {})
485
+ if target_metrics.get('is_constant', False):
486
+ recommendations.append(f"Target variable '{self.config.target_column}' is constant, different target variable needed")
487
+
488
+ if target_metrics.get('has_outliers', {}).get('has_outliers', False):
489
+ outlier_pct = target_metrics['has_outliers'].get('outlier_percentage', 0)
490
+ if outlier_pct > 5:
491
+ recommendations.append(f"Handle outliers in target variable: {outlier_pct:.1f}% outliers")
492
+
493
+ # Time series stability recommendations
494
+ temporal_stability = quality_metrics.get('temporal_stability', {})
495
+ if temporal_stability.get('time_gaps', {}).get('has_irregular_gaps', False):
496
+ recommendations.append("Detected irregular time intervals, consider resampling")
497
+
498
+ return recommendations
499
+
500
+ def _calculate_overall_score(self, validation_results: Dict) -> float:
501
+ """Calculate overall data quality score"""
502
+ score = 100
503
+
504
+ # Penalties for basic checks
505
+ basic_checks = validation_results.get('basic_checks', {})
506
+ for check_name, check_info in basic_checks.items():
507
+ if not check_info.get('passed', True):
508
+ if check_name == 'min_rows':
509
+ score -= 30
510
+ elif check_name == 'has_target':
511
+ score -= 50
512
+ elif check_name == 'missing_percentage':
513
+ missing_pct = check_info.get('value', 0)
514
+ if missing_pct > 50:
515
+ score -= 40
516
+ elif missing_pct > 20:
517
+ score -= 20
518
+ elif missing_pct > 5:
519
+ score -= 10
520
+ elif check_name == 'duplicates':
521
+ duplicate_pct = check_info.get('value', 0)
522
+ if duplicate_pct > 20:
523
+ score -= 30
524
+ elif duplicate_pct > 10:
525
+ score -= 15
526
+ elif duplicate_pct > 5:
527
+ score -= 5
528
+
529
+ # Penalties for issues
530
+ issues = validation_results.get('issues', {})
531
+ if issues.get('critical'):
532
+ score -= len(issues['critical']) * 20
533
+
534
+ if issues.get('warning'):
535
+ score -= len(issues['warning']) * 5
536
+
537
+ # Bonuses for good metrics
538
+ quality_metrics = validation_results.get('quality_metrics', {})
539
+ target_metrics = quality_metrics.get('target_quality', {})
540
+
541
+ if not target_metrics.get('is_constant', True):
542
+ score += 10
543
+
544
+ if target_metrics.get('missing_percentage', 100) < 1:
545
+ score += 5
546
+
547
+ # Limit score to 0-100 range
548
+ return max(0, min(100, score))
549
+
550
+ def _log_validation_results(self, validation_results: Dict) -> None:
551
+ """Log validation results"""
552
+ stage = validation_results['stage']
553
+ status = validation_results['status']
554
+ score = validation_results['overall_score']
555
+
556
+ logger.info(f"VALIDATION RESULTS ({stage}):")
557
+ logger.info(f" Status: {status}")
558
+ logger.info(f" Overall score: {score}/100")
559
+ logger.info(f" Data shape: {validation_results['data_shape'][0]}x{validation_results['data_shape'][1]}")
560
+
561
+ # Basic checks
562
+ logger.info("\nBASIC CHECKS:")
563
+ for check_name, check_info in validation_results['basic_checks'].items():
564
+ status_icon = "✓" if check_info.get('passed', True) else "✗"
565
+ logger.info(f" {status_icon} {check_name}: {check_info.get('value', 'N/A')}")
566
+
567
+ # Issues
568
+ issues = validation_results['issues']
569
+ if any(issues.values()):
570
+ logger.info("\nDETECTED ISSUES:")
571
+ for severity, issue_list in issues.items():
572
+ if issue_list:
573
+ logger.info(f" {severity.upper()}:")
574
+ for issue in issue_list[:5]: # Show only first 5 issues of each type
575
+ logger.info(f" - {issue}")
576
+ if len(issue_list) > 5:
577
+ logger.info(f" ... and {len(issue_list) - 5} more issues")
578
+ else:
579
+ logger.info("\n✓ No issues detected")
580
+
581
+ # Recommendations
582
+ recommendations = validation_results['recommendations']
583
+ if recommendations:
584
+ logger.info("\nRECOMMENDATIONS:")
585
+ for i, rec in enumerate(recommendations, 1):
586
+ logger.info(f" {i}. {rec}")
587
+
588
+ # Conclusion
589
+ if status == 'PASS':
590
+ logger.info("\n✓ Data passed validation and is ready for use")
591
+ elif status == 'WARNING':
592
+ logger.info("\n⚠ Data requires attention, there are issues to address")
593
+ else:
594
+ logger.info("\n✗ Data requires significant improvement before use")
595
+
596
+ def generate_report(self, stage: str = 'final') -> Dict:
597
+ """Generate detailed validation report"""
598
+ if stage not in self.validation_results:
599
+ return {}
600
+
601
+ report = self.validation_results[stage].copy()
602
+
603
+ # Add metadata
604
+ report['config'] = self.config.to_dict()
605
+ report['validator_version'] = '1.0'
606
+ report['generation_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
607
+
608
+ # Add detailed metrics
609
+ quality_metrics = report.get('quality_metrics', {})
610
+
611
+ if 'numeric_statistics' in quality_metrics:
612
+ # Numeric features summary
613
+ numeric_stats = quality_metrics['numeric_statistics']
614
+ report['numeric_summary'] = {
615
+ 'total_numeric_features': len(numeric_stats),
616
+ 'features_with_high_skewness': sum(1 for s in numeric_stats.values() if abs(s.get('skewness', 0)) > 1),
617
+ 'features_with_high_kurtosis': sum(1 for s in numeric_stats.values() if abs(s.get('kurtosis', 0)) > 3),
618
+ 'features_with_many_zeros': sum(1 for s in numeric_stats.values() if s.get('zeros_percentage', 0) > 50)
619
+ }
620
+
621
+ return report
622
+
623
+ def save_report(self, stage: str = 'final', path: str = None) -> None:
624
+ """Save validation report to file"""
625
+ if stage not in self.validation_results:
626
+ logger.warning(f"Report for stage '{stage}' not found")
627
+ return
628
+
629
+ report = self.generate_report(stage)
630
+
631
+ if path is None:
632
+ path = f'{self.config.results_dir}/reports/validation_report_{stage}.json'
633
+
634
+ # Create directory if needed
635
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
636
+
637
+ # Custom JSON encoder
638
+ class NumpyEncoder(json.JSONEncoder):
639
+ def default(self, obj):
640
+ if isinstance(obj, (np.integer, np.floating)):
641
+ if np.isnan(obj):
642
+ return None
643
+ return float(obj)
644
+ elif isinstance(obj, np.bool_):
645
+ return bool(obj)
646
+ elif isinstance(obj, np.ndarray):
647
+ return obj.tolist()
648
+ elif isinstance(obj, pd.Timestamp):
649
+ return obj.strftime('%Y-%m-%d %H:%M:%S')
650
+ return super().default(obj)
651
+
652
+ with open(path, 'w', encoding='utf-8') as f:
653
+ json.dump(report, f, indent=4, ensure_ascii=False, cls=NumpyEncoder)
654
+
655
+ logger.info(f"✓ Validation report saved: {path}")
visualization/__init__.py ADDED
File without changes
visualization/visualization_manager.py ADDED
@@ -0,0 +1,1462 @@
1
+ # ============================================
2
+ # CLASS 13: VISUALISATION MANAGER (UPDATED)
3
+ # ============================================
4
+ import os
5
+ from datetime import datetime
6
+ import json
7
+ from typing import Dict, List, Optional, Tuple, Union, Any
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+ import matplotlib
12
+ matplotlib.use('Agg')  # Select the non-interactive backend before importing pyplot
13
+ import matplotlib.pyplot as plt
14
+ from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
15
+ import seaborn as sns
16
+ from scipy.stats import gaussian_kde
17
+
18
+ from config.config import Config
19
+ import logging
20
+
21
+ # Logging setup
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class VisualisationManager:
27
+ """Class for managing all visualisations"""
28
+
29
+ def __init__(self, config: Config):
30
+ """
31
+ Initialise visualisation manager
32
+
33
+ Parameters:
34
+ -----------
35
+ config : Config
36
+ Experiment configuration
37
+ """
38
+ self.config = config
39
+ self.plots_generated = {}
40
+ self.plot_files = {}
41
+ self.figure_count = 0
42
+
43
+ # Create directory structure for saving plots
44
+ self._create_directory_structure()
45
+
46
+ def _create_directory_structure(self) -> None:
47
+ """Create directory structure for saving plots"""
48
+ base_dir = self.config.results_dir
49
+
50
+ # Main plot directories
51
+ self.plots_dir = os.path.join(base_dir, "plots")
52
+ self.correlations_dir = os.path.join(base_dir, "plots", "correlations")
53
+ self.distributions_dir = os.path.join(base_dir, "plots", "distributions")
54
+ self.features_dir = os.path.join(base_dir, "plots", "features")
55
+ self.time_series_dir = os.path.join(base_dir, "plots", "time_series")
56
+ self.preprocessing_dir = os.path.join(base_dir, "plots", "preprocessing")
57
+ self.summary_dir = os.path.join(base_dir, "plots", "summary")
58
+ self.reports_dir = os.path.join(base_dir, "reports")
59
+
60
+ # Create directories
61
+ directories = [
62
+ self.plots_dir,
63
+ self.correlations_dir,
64
+ self.distributions_dir,
65
+ self.features_dir,
66
+ self.time_series_dir,
67
+ self.preprocessing_dir,
68
+ self.summary_dir,
69
+ self.reports_dir
70
+ ]
71
+
72
+ for directory in directories:
73
+ os.makedirs(directory, exist_ok=True)
74
+ logger.debug(f"Created directory: {directory}")
75
+
76
+ def _save_figure(self, fig: plt.Figure, filename: str,
77
+ subdirectory: str = None, dpi: int = 300) -> str:
78
+ """
79
+ Save plot and close it
80
+
81
+ Parameters:
82
+ -----------
83
+ fig : matplotlib.figure.Figure
84
+ Plot figure object
85
+ filename : str
86
+ Filename for saving
87
+ subdirectory : str, optional
88
+ Subdirectory for saving
89
+ dpi : int
90
+ Save quality
91
+
92
+ Returns:
93
+ --------
94
+ str : full path to saved file
95
+ """
96
+ if not filename.endswith('.png'):
97
+ filename = f"{filename}.png"
98
+
99
+ if subdirectory:
100
+ save_dir = os.path.join(self.plots_dir, subdirectory)
101
+ os.makedirs(save_dir, exist_ok=True)
102
+ else:
103
+ save_dir = self.plots_dir
104
+
105
+ filepath = os.path.join(save_dir, filename)
106
+
107
+ try:
108
+ fig.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white')
109
+ logger.info(f"✓ Plot saved: {filepath}")
110
+ except Exception as e:
111
+ logger.error(f"✗ Error saving plot {filename}: {e}")
112
+ filepath = None
113
+
114
+ # Close plot without display
115
+ plt.close(fig)
116
+
117
+ return filepath
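+
+ # Usage sketch (illustrative, not part of the pipeline): any matplotlib
+ # figure can be routed through this helper, e.g.
+ #   fig, ax = plt.subplots()
+ #   ax.plot(range(10))
+ #   manager._save_figure(fig, "demo_plot", subdirectory="summary")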
118
+
119
+ # ============================================
120
+ # MAIN VISUALISATION METHODS
121
+ # ============================================
122
+
123
+ def create_summary_dashboard(
124
+ self,
125
+ data: pd.DataFrame,
126
+ preprocessing_stages: Dict = None,
127
+ filename: str = "summary_dashboard"
128
+ ) -> str:
129
+ """
130
+ Create summary visualisation dashboard
131
+
132
+ Parameters:
133
+ -----------
134
+ data : pd.DataFrame
135
+ Data for visualisation
136
+ preprocessing_stages : Dict, optional
137
+ Preprocessing stages information
138
+ filename : str
139
+ Filename for saving
140
+
141
+ Returns:
142
+ --------
143
+ str : path to saved file or None if error
144
+ """
145
+ logger.info("\n" + "="*80)
146
+ logger.info("CREATING SUMMARY DASHBOARD")
147
+ logger.info("="*80)
148
+
149
+ target_col = self.config.target_column
150
+
151
+ try:
152
+ # Create large dashboard
153
+ fig = plt.figure(figsize=(20, 24))
154
+ gs = fig.add_gridspec(6, 4, hspace=0.3, wspace=0.3)
155
+
156
+ # 1. Time series of target variable
157
+ ax1 = fig.add_subplot(gs[0, :2])
158
+ if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
159
+ ax1.plot(data.index, data[target_col], linewidth=1, color='blue', alpha=0.7)
160
+ ax1.set_title(f'Time Series: {target_col}', fontsize=12, fontweight='bold')
161
+ ax1.set_xlabel('Date', fontsize=10)
162
+ ax1.set_ylabel(target_col, fontsize=10)
163
+ ax1.grid(True, alpha=0.3)
164
+ ax1.tick_params(axis='x', rotation=45)
165
+ else:
166
+ ax1.text(0.5, 0.5, 'No time series data available',
167
+ ha='center', va='center', transform=ax1.transAxes)
168
+
169
+ # 2. Target variable distribution
170
+ ax2 = fig.add_subplot(gs[0, 2:])
171
+ if target_col in data.columns:
172
+ values = data[target_col].dropna()
173
+ if len(values) > 0:
174
+ ax2.hist(values, bins=30, edgecolor='black', alpha=0.7, color='green')
175
+ ax2.set_title(f'Distribution: {target_col}', fontsize=12, fontweight='bold')
176
+ ax2.set_xlabel(target_col, fontsize=10)
177
+ ax2.set_ylabel('Frequency', fontsize=10)
178
+ ax2.grid(True, alpha=0.3)
179
+ else:
180
+ ax2.text(0.5, 0.5, 'No data for distribution',
181
+ ha='center', va='center', transform=ax2.transAxes)
182
+
183
+ # 3. Correlation matrix (top features)
184
+ ax3 = fig.add_subplot(gs[1, :])
185
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
186
+ if len(numeric_cols) > 1:
187
+ display_cols = list(numeric_cols[:15])
188
+ if target_col not in display_cols and target_col in data.columns:
189
+ display_cols = [target_col] + [c for c in display_cols if c != target_col][:14]
190
+
191
+ corr_matrix = data[display_cols].corr()
192
+ mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
193
+
194
+ im = ax3.imshow(corr_matrix.where(~mask), cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
195
+ ax3.set_title('Correlation Matrix (Top 15 Features)',
196
+ fontsize=12, fontweight='bold')
197
+ ax3.set_xticks(range(len(display_cols)))
198
+ ax3.set_yticks(range(len(display_cols)))
199
+ ax3.set_xticklabels(display_cols, rotation=90, fontsize=8)
200
+ ax3.set_yticklabels(display_cols, fontsize=8)
201
+ plt.colorbar(im, ax=ax3, shrink=0.8)
202
+
203
+ # 4. Seasonal patterns
204
+ ax4 = fig.add_subplot(gs[2, :2])
205
+ if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
206
+ data_copy = data.copy()
207
+ data_copy['month'] = data_copy.index.month
208
+
209
+ monthly_avg = data_copy.groupby('month')[target_col].mean()
210
+ colors = plt.cm.Set3(np.linspace(0, 1, len(monthly_avg)))
211
+ ax4.bar(monthly_avg.index, monthly_avg.values, color=colors, edgecolor='black')
212
+ ax4.set_title('Average Values by Month', fontsize=12, fontweight='bold')
213
+ ax4.set_xlabel('Month', fontsize=10)
214
+ ax4.set_ylabel(f'Average {target_col}', fontsize=10)
215
+ ax4.set_xticks(range(1, 13))
216
+ month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
217
+ 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
218
+ ax4.set_xticklabels(month_names)
219
+ ax4.grid(True, alpha=0.3, axis='y')
220
+
221
+ # 5. Weekly patterns
222
+ ax5 = fig.add_subplot(gs[2, 2:])
223
+ if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
224
+ data_copy = data.copy()
225
+ data_copy['dayofweek'] = data_copy.index.dayofweek
226
+
227
+ daily_avg = data_copy.groupby('dayofweek')[target_col].mean()
228
+ colors = plt.cm.Paired(np.linspace(0, 1, len(daily_avg)))
229
+ ax5.bar(daily_avg.index, daily_avg.values, color=colors, edgecolor='black')
230
+ ax5.set_title('Average Values by Day of Week', fontsize=12, fontweight='bold')
231
+ ax5.set_xlabel('Day of Week', fontsize=10)
232
+ ax5.set_ylabel(f'Average {target_col}', fontsize=10)
233
+ ax5.set_xticks(range(7))
234
+ ax5.set_xticklabels(['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'])
235
+ ax5.grid(True, alpha=0.3, axis='y')
236
+
237
+ # 6. Trend and seasonality
238
+ ax6 = fig.add_subplot(gs[3, :])
239
+ if target_col in data.columns and len(data) > 30:
240
+ try:
241
+ window_size = min(365, len(data) // 10)
242
+ if window_size >= 7:
243
+ rolling_mean = data[target_col].rolling(window=window_size, center=True).mean()
244
+ rolling_std = data[target_col].rolling(window=window_size, center=True).std()
245
+
246
+ ax6.plot(data.index, data[target_col], alpha=0.5,
247
+ label='Original Series', linewidth=0.5, color='blue')
248
+ ax6.plot(rolling_mean.index, rolling_mean,
249
+ label=f'Rolling Mean ({window_size} days)',
250
+ color='red', linewidth=2)
251
+ ax6.fill_between(rolling_mean.index,
252
+ rolling_mean - rolling_std,
253
+ rolling_mean + rolling_std,
254
+ alpha=0.2, color='red')
255
+
256
+ ax6.set_title('Trend and Volatility', fontsize=12, fontweight='bold')
257
+ ax6.set_xlabel('Date', fontsize=10)
258
+ ax6.set_ylabel(target_col, fontsize=10)
259
+ ax6.legend(fontsize=9, loc='upper left')
260
+ ax6.grid(True, alpha=0.3)
261
+ else:
262
+ ax6.text(0.5, 0.5, 'Insufficient data for trend analysis',
263
+ ha='center', va='center', transform=ax6.transAxes)
264
+ except Exception as e:
265
+ logger.warning(f"Error plotting trend: {e}")
266
+ ax6.text(0.5, 0.5, 'Error plotting trend',
267
+ ha='center', va='center', transform=ax6.transAxes)
268
+
269
+ # 7. Preprocessing statistics
270
+ if preprocessing_stages:
271
+ ax7 = fig.add_subplot(gs[4, :2])
272
+
273
+ stages = list(preprocessing_stages.keys())
274
+ values = list(preprocessing_stages.values())
275
+
276
+ colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(stages)))
277
+ bars = ax7.bar(range(len(stages)), values, color=colors, edgecolor='black')
278
+ ax7.set_title('Preprocessing Statistics', fontsize=12, fontweight='bold')
279
+ ax7.set_xlabel('Processing Stage', fontsize=10)
280
+ ax7.set_ylabel('Value', fontsize=10)
281
+ ax7.set_xticks(range(len(stages)))
282
+ ax7.set_xticklabels([s[:15] + '...' if len(s) > 15 else s for s in stages],
283
+ rotation=45, ha='right', fontsize=9)
284
+ ax7.grid(True, alpha=0.3, axis='y')
285
+
286
+ # Add values on bars
287
+ for bar, value in zip(bars, values):
288
+ height = bar.get_height()
289
+ ax7.text(bar.get_x() + bar.get_width()/2., height,
290
+ f'{value:.2f}', ha='center', va='bottom', fontsize=8)
291
+
292
+ # 8. Data information
293
+ ax8 = fig.add_subplot(gs[4, 2:])
294
+ ax8.axis('off')
295
+
296
+ info_text = []
297
+ info_text.append("GENERAL CHARACTERISTICS:")
298
+ info_text.append(f"• Number of records: {len(data):,}")
299
+ info_text.append(f"• Number of features: {len(data.columns)}")
300
+
301
+ if isinstance(data.index, pd.DatetimeIndex):
302
+ info_text.append(f"• Period: {data.index.min().strftime('%Y-%m-%d')} - "
303
+ f"{data.index.max().strftime('%Y-%m-%d')}")
304
+ info_text.append(f"• Days of data: {(data.index.max() - data.index.min()).days}")
305
+
306
+ if target_col in data.columns:
307
+ target_stats = data[target_col].describe()
308
+ info_text.append(f"\nTARGET VARIABLE '{target_col}':")
309
+ info_text.append(f"• Mean: {target_stats['mean']:.2f}")
310
+ info_text.append(f"• Standard deviation: {target_stats['std']:.2f}")
311
+ info_text.append(f"• Minimum: {target_stats['min']:.2f}")
312
+ info_text.append(f"• 25%: {target_stats['25%']:.2f}")
313
+ info_text.append(f"• 50% (median): {target_stats['50%']:.2f}")
314
+ info_text.append(f"• 75%: {target_stats['75%']:.2f}")
315
+ info_text.append(f"• Maximum: {target_stats['max']:.2f}")
316
+
317
+ info_text.append(f"\nDATA TYPES:")
318
+ for dtype, count in data.dtypes.value_counts().items():
319
+ info_text.append(f"• {dtype}: {count} columns")
320
+
321
+ missing_info = data.isnull().sum()
322
+ missing_total = missing_info.sum()
323
+ missing_percent = missing_total / data.size * 100
324
+ info_text.append(f"\nMISSING VALUES:")
325
+ info_text.append(f"• Total missing: {missing_total:,}")
326
+ info_text.append(f"• Missing percentage: {missing_percent:.2f}%")
327
+
328
+ if missing_total > 0:
329
+ top_missing = missing_info.nlargest(5)
330
+ info_text.append(f"• Top 5 columns with missing values:")
331
+ for col, count in top_missing.items():
332
+ percent = count / len(data) * 100
333
+ info_text.append(f" {col}: {count} ({percent:.1f}%)")
334
+
335
+ ax8.text(0.02, 0.98, '\n'.join(info_text), transform=ax8.transAxes,
336
+ fontsize=8, verticalalignment='top', fontfamily='monospace')
337
+
338
+ # 9. Autocorrelation plot
339
+ ax9 = fig.add_subplot(gs[5, :2])
340
+ if target_col in data.columns:
341
+ try:
342
+ series = data[target_col].dropna()
343
+ if len(series) > 50:
344
+ plot_acf(series, lags=min(50, len(series)-1), ax=ax9, alpha=0.05)
345
+ ax9.set_title('Autocorrelation Function (ACF)', fontsize=12, fontweight='bold')
346
+ ax9.set_xlabel('Lag', fontsize=10)
347
+ ax9.set_ylabel('Autocorrelation', fontsize=10)
348
+ ax9.grid(True, alpha=0.3)
349
+ else:
350
+ ax9.text(0.5, 0.5, 'Insufficient data for ACF',
351
+ ha='center', va='center', transform=ax9.transAxes)
352
+ except Exception as e:
353
+ logger.warning(f"Error plotting ACF: {e}")
354
+ ax9.text(0.5, 0.5, 'Error calculating ACF',
355
+ ha='center', va='center', transform=ax9.transAxes)
356
+
357
+ # 10. Partial autocorrelation plot
358
+ ax10 = fig.add_subplot(gs[5, 2:])
359
+ if target_col in data.columns:
360
+ try:
361
+ series = data[target_col].dropna()
362
+ if len(series) > 50:
363
+ plot_pacf(series, lags=min(50, len(series)-1), ax=ax10, alpha=0.05)
364
+ ax10.set_title('Partial Autocorrelation Function (PACF)',
365
+ fontsize=12, fontweight='bold')
366
+ ax10.set_xlabel('Lag', fontsize=10)
367
+ ax10.set_ylabel('Partial Autocorrelation', fontsize=10)
368
+ ax10.grid(True, alpha=0.3)
369
+ else:
370
+ ax10.text(0.5, 0.5, 'Insufficient data for PACF',
371
+ ha='center', va='center', transform=ax10.transAxes)
372
+ except Exception as e:
373
+ logger.warning(f"Error plotting PACF: {e}")
374
+ ax10.text(0.5, 0.5, 'Error calculating PACF',
375
+ ha='center', va='center', transform=ax10.transAxes)
376
+
377
+ plt.suptitle('Data Analysis Summary Dashboard', fontsize=16, fontweight='bold', y=0.98)
378
+ plt.tight_layout()
379
+
380
+ # Save
381
+ filepath = self._save_figure(fig, filename, "summary")
382
+ self.plot_files['summary_dashboard'] = filepath
383
+ return filepath
384
+
385
+ except Exception as e:
386
+ logger.error(f"Error creating summary dashboard: {e}")
387
+ return None
388
+
389
+ # ============================================
390
+ # SPECIFIC METHODS FOR SAVING PIPELINE PLOTS
391
+ # ============================================
392
+
393
+ def save_data_split_plot(self, filename: str = "data_split.png") -> str:
394
+ """
395
+ Save data split plot
396
+
397
+ Parameters:
398
+ -----------
399
+ filename : str
400
+ Filename for saving
401
+
402
+ Returns:
403
+ --------
404
+ str : path to saved file
405
+ """
406
+ try:
407
+ fig = plt.gcf() # Get current figure
408
+ filepath = self._save_figure(fig, filename, "time_series")
409
+ self.plot_files['data_split'] = filepath
410
+ return filepath
411
+ except Exception as e:
412
+ logger.error(f"Error saving data_split plot: {e}")
413
+ return None
414
+
415
+ def save_feature_selection_correlation_plot(self, filename: str = "feature_selection_correlation.png") -> str:
416
+ """
417
+ Save feature selection correlation plot
418
+
419
+ Parameters:
420
+ -----------
421
+ filename : str
422
+ Filename for saving
423
+
424
+ Returns:
425
+ --------
426
+ str : path to saved file
427
+ """
428
+ try:
429
+ fig = plt.gcf() # Get current figure
430
+ filepath = self._save_figure(fig, filename, "correlations")
431
+ self.plot_files['feature_selection_correlation'] = filepath
432
+ return filepath
433
+ except Exception as e:
434
+ logger.error(f"Error saving feature_selection_correlation plot: {e}")
435
+ return None
436
+
437
+ def save_missing_values_analysis_plot(self, filename: str = "missing_values_analysis.png") -> str:
438
+ """
439
+ Save missing values analysis plot
440
+
441
+ Parameters:
442
+ -----------
443
+ filename : str
444
+ Filename for saving
445
+
446
+ Returns:
447
+ --------
448
+ str : path to saved file
449
+ """
450
+ try:
451
+ fig = plt.gcf() # Get current figure
452
+ filepath = self._save_figure(fig, filename, "preprocessing")
453
+ self.plot_files['missing_values_analysis'] = filepath
454
+ return filepath
455
+ except Exception as e:
456
+ logger.error(f"Error saving missing_values_analysis plot: {e}")
457
+ return None
458
+
459
+ def save_outlier_handling_results_plot(self, filename: str = "outlier_handling_results.png") -> str:
460
+ """
461
+ Save outlier handling results plot
462
+
463
+ Parameters:
464
+ -----------
465
+ filename : str
466
+ Filename for saving
467
+
468
+ Returns:
469
+ --------
470
+ str : path to saved file
471
+ """
472
+ try:
473
+ fig = plt.gcf() # Get current figure
474
+ filepath = self._save_figure(fig, filename, "preprocessing")
475
+ self.plot_files['outlier_handling_results'] = filepath
476
+ return filepath
477
+ except Exception as e:
478
+ logger.error(f"Error saving outlier_handling_results plot: {e}")
479
+ return None
480
+
481
+ def save_outliers_analysis_plot(self, filename: str = "outliers_analysis.png") -> str:
482
+ """
483
+ Save outliers analysis plot
484
+
485
+ Parameters:
486
+ -----------
487
+ filename : str
488
+ Filename for saving
489
+
490
+ Returns:
491
+ --------
492
+ str : path to saved file
493
+ """
494
+ try:
495
+ fig = plt.gcf() # Get current figure
496
+ filepath = self._save_figure(fig, filename, "preprocessing")
497
+ self.plot_files['outliers_analysis'] = filepath
498
+ return filepath
499
+ except Exception as e:
500
+ logger.error(f"Error saving outliers_analysis plot: {e}")
501
+ return None
502
+
503
+ def save_scaling_results_plot(self, filename: str = "scaling_results.png") -> str:
504
+ """
505
+ Save scaling results plot
506
+
507
+ Parameters:
508
+ -----------
509
+ filename : str
510
+ Filename for saving
511
+
512
+ Returns:
513
+ --------
514
+ str : path to saved file
515
+ """
516
+ try:
517
+ fig = plt.gcf() # Get current figure
518
+ filepath = self._save_figure(fig, filename, "preprocessing")
519
+ self.plot_files['scaling_results'] = filepath
520
+ return filepath
521
+ except Exception as e:
522
+ logger.error(f"Error saving scaling_results plot: {e}")
523
+ return None
524
+
525
+ def save_stationarity_analysis_plot(self, filename: str = "stationarity_analysis.png") -> str:
526
+ """
527
+ Save stationarity analysis plot
528
+
529
+ Parameters:
530
+ -----------
531
+ filename : str
532
+ Filename for saving
533
+
534
+ Returns:
535
+ --------
536
+ str : path to saved file
537
+ """
538
+ try:
539
+ fig = plt.gcf() # Get current figure
540
+ filepath = self._save_figure(fig, filename, "time_series")
541
+ self.plot_files['stationarity_analysis'] = filepath
542
+ return filepath
543
+ except Exception as e:
544
+ logger.error(f"Error saving stationarity_analysis plot: {e}")
545
+ return None
546
+
547
+ def save_temporal_outliers_plot(self, filename: str = "temporal_outliers.png") -> str:
548
+ """
549
+ Save temporal outliers plot
550
+
551
+ Parameters:
552
+ -----------
553
+ filename : str
554
+ Filename for saving
555
+
556
+ Returns:
557
+ --------
558
+ str : path to saved file
559
+ """
560
+ try:
561
+ fig = plt.gcf() # Get current figure
562
+ filepath = self._save_figure(fig, filename, "time_series")
563
+ self.plot_files['temporal_outliers'] = filepath
564
+ return filepath
565
+ except Exception as e:
566
+ logger.error(f"Error saving temporal_outliers plot: {e}")
567
+ return None
568
+
569
+ # ============================================
570
+ # UNIVERSAL METHOD FOR SAVING ANY PLOT
571
+ # ============================================
572
+
573
+ def save_current_plot(self, filename: str, subdirectory: str = None) -> str:
574
+ """
575
+ Universal method for saving current plot
576
+
577
+ Parameters:
578
+ -----------
579
+ filename : str
580
+ Filename for saving
581
+ subdirectory : str, optional
582
+ Subdirectory for saving
583
+
584
+ Returns:
585
+ --------
586
+ str : path to saved file
587
+ """
588
+ try:
589
+ fig = plt.gcf() # Get current figure
590
+ filepath = self._save_figure(fig, filename, subdirectory)
591
+
592
+ # Save plot information
593
+ plot_key = filename.replace('.png', '').replace('.jpg', '')
594
+ self.plot_files[plot_key] = filepath
595
+
596
+ return filepath
597
+ except Exception as e:
598
+ logger.error(f"Error saving plot {filename}: {e}")
599
+ return None
600
+
601
+ # ============================================
602
+ # ADDITIONAL VISUALISATION METHODS
603
+ # ============================================
604
+
605
+ def create_feature_importance_plot(
606
+ self,
607
+ feature_importance: Dict,
608
+ top_n: int = 20,
609
+ filename: str = "feature_importance"
610
+ ) -> str:
611
+ """
612
+ Create feature importance plot
613
+
614
+ Parameters:
615
+ -----------
616
+ feature_importance : Dict
617
+ Dictionary with feature importance
618
+ top_n : int
619
+ Number of top features to display
620
+ filename : str
621
+ Filename for saving
622
+
623
+ Returns:
624
+ --------
625
+ str : path to saved file or None if error
626
+ """
627
+ if not feature_importance:
628
+ logger.warning("No feature importance data for visualisation")
629
+ return None
630
+
631
+ try:
632
+ # Convert to Series and sort
633
+ importance_series = pd.Series(feature_importance).sort_values(ascending=False)
634
+ top_features = importance_series.head(top_n)
635
+
636
+ # Create plot
637
+ fig, ax = plt.subplots(figsize=(12, 8))
638
+
639
+ y_pos = np.arange(len(top_features))
640
+ colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(top_features)))
641
+
642
+ bars = ax.barh(y_pos, top_features.values, color=colors, edgecolor='black')
643
+ ax.set_yticks(y_pos)
644
+ ax.set_yticklabels(top_features.index, fontsize=10)
645
+ ax.invert_yaxis()
646
+ ax.set_xlabel('Feature Importance', fontsize=11, fontweight='bold')
647
+ ax.set_title(f'Top-{top_n} Most Important Features', fontsize=14, fontweight='bold')
648
+ ax.grid(True, alpha=0.3, axis='x')
649
+
650
+ # Add values on bars
651
+ for i, (bar, value) in enumerate(zip(bars, top_features.values)):
652
+ width = bar.get_width()
653
+ ax.text(width * 1.01, bar.get_y() + bar.get_height()/2,
654
+ f'{value:.4f}', va='center', fontsize=9, fontweight='bold')
655
+
656
+ # Add additional information
657
+ plt.text(0.02, 0.98, f'Total features: {len(importance_series)}',
658
+ transform=fig.transFigure, fontsize=9, verticalalignment='top')
659
+
660
+ plt.tight_layout()
661
+
662
+ # Save
663
+ filepath = self._save_figure(fig, filename, "features")
664
+ self.plot_files['feature_importance'] = filepath
665
+ return filepath
666
+
667
+ except Exception as e:
668
+ logger.error(f"Error creating feature importance plot: {e}")
669
+ return None
670
+
671
+ def create_correlation_heatmap(
672
+ self,
673
+ data: pd.DataFrame,
674
+ top_n: int = 20,
675
+ filename: str = "correlation_heatmap"
676
+ ) -> Tuple[str, Optional[str]]:
677
+ """
678
+ Create correlation heatmap
679
+
680
+ Parameters:
681
+ -----------
682
+ data : pd.DataFrame
683
+ Data for analysis
684
+ top_n : int
685
+ Number of top features to display
686
+ filename : str
687
+ Filename for saving
688
+
689
+ Returns:
690
+ --------
691
+ Tuple[str, Optional[str]]:
692
+ (path to main heatmap, path to target correlation heatmap)
693
+ """
694
+ target_col = self.config.target_column
695
+
696
+ try:
697
+ numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
698
+
699
+ if len(numeric_cols) < 2:
700
+ logger.warning("Insufficient numeric features for correlation analysis")
701
+ return None, None
702
+
703
+ # Create two heatmaps
704
+
705
+ # 1. Main correlation heatmap between all features
706
+ main_filepath = self._create_main_correlation_heatmap(data, numeric_cols, top_n, filename)
707
+
708
+ # 2. Target correlation heatmap
709
+ target_filepath = None
710
+ if target_col in data.columns and target_col in numeric_cols:
711
+ target_filepath = self._create_target_correlation_heatmap(data, target_col, numeric_cols, filename)
712
+
713
+ return main_filepath, target_filepath
714
+
715
+ except Exception as e:
716
+ logger.error(f"Error creating correlation heatmap: {e}")
717
+ return None, None
718
+
719
+ def _create_main_correlation_heatmap(
720
+ self,
721
+ data: pd.DataFrame,
722
+ numeric_cols: List[str],
723
+ top_n: int,
724
+ filename: str
725
+ ) -> str:
726
+ """Create main correlation heatmap"""
727
+ # Limit number of features for better readability
728
+ if len(numeric_cols) > top_n:
729
+ # Select features with highest variance
730
+ variances = data[numeric_cols].var().sort_values(ascending=False)
731
+ selected_cols = variances.head(top_n).index.tolist()
732
+ else:
733
+ selected_cols = numeric_cols
734
+
735
+ # Calculate correlation
736
+ corr_matrix = data[selected_cols].corr()
737
+
738
+ fig, ax = plt.subplots(figsize=(14, 12))
739
+
740
+ # Mask for upper triangle
741
+ mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
742
+
743
+ # Create heatmap
744
+ sns.heatmap(
745
+ corr_matrix,
746
+ annot=True,
747
+ fmt='.2f',
748
+ cmap='coolwarm',
749
+ center=0,
750
+ square=True,
751
+ mask=mask,
752
+ cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'},
753
+ linewidths=0.5,
754
+ linecolor='white',
755
+ ax=ax,
756
+ annot_kws={'size': 8}
757
+ )
758
+
759
+ ax.set_title(f'Correlation Matrix Between Features (Top-{top_n})',
760
+ fontsize=14, fontweight='bold', pad=20)
761
+
762
+ plt.tight_layout()
763
+
764
+ # Save
765
+ filepath = self._save_figure(fig, filename, "correlations")
766
+ self.plot_files['correlation_heatmap_main'] = filepath
767
+ return filepath
768
+
769
+ def _create_target_correlation_heatmap(
770
+ self,
771
+ data: pd.DataFrame,
772
+ target_col: str,
773
+ numeric_cols: List[str],
774
+ filename: str
775
+ ) -> str:
776
+ """Create target correlation heatmap"""
777
+ # Calculate correlations with target variable
778
+ correlations = data[numeric_cols].corrwith(data[target_col]).sort_values(key=abs, ascending=False)
779
+
780
+ # Exclude target variable itself
781
+ correlations = correlations[correlations.index != target_col]
782
+
783
+ # Take top 15 features
784
+ top_features = correlations.head(15)
785
+
786
+ fig, ax = plt.subplots(figsize=(10, 8))
787
+
788
+ colors = ['red' if x < 0 else 'green' for x in top_features.values]
789
+ bars = ax.barh(range(len(top_features)), top_features.values, color=colors, edgecolor='black')
790
+
791
+ ax.set_yticks(range(len(top_features)))
792
+ ax.set_yticklabels(top_features.index, fontsize=10)
793
+ ax.invert_yaxis()
794
+ ax.set_xlabel('Correlation Coefficient', fontsize=11, fontweight='bold')
795
+ ax.set_title(f'Feature Correlations with Target Variable "{target_col}"',
796
+ fontsize=14, fontweight='bold', pad=20)
797
+ ax.grid(True, alpha=0.3, axis='x')
798
+ ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
799
+
800
+ # Add values on bars
801
+ for bar, value in zip(bars, top_features.values):
802
+ width = bar.get_width()
803
+ ax.text(width + (0.01 if width >= 0 else -0.04),
804
+ bar.get_y() + bar.get_height()/2,
805
+ f'{value:.3f}',
806
+ va='center',
807
+ ha='left' if width >= 0 else 'right',
808
+ fontsize=9,
809
+ fontweight='bold',
810
+ color='black')
811
+
812
+ plt.tight_layout()
813
+
814
+ # Save
815
+ target_filename = f"{filename}_with_target"
816
+ filepath = self._save_figure(fig, target_filename, "correlations")
817
+ self.plot_files['correlation_with_target'] = filepath
818
+ return filepath
819
+
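Despite its name, this helper renders a signed horizontal bar chart, and the ranking is by absolute correlation, so strong negative relationships surface alongside positive ones. The core selection in isolation (note that the `key` argument to `sort_values` requires pandas 1.1+):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(4)
df = pd.DataFrame({"a": rng.normal(size=200), "b": rng.normal(size=200)})
df["target"] = 2.0 * df["a"] - 0.5 * df["b"] + rng.normal(0, 0.1, 200)

corr = df.corrwith(df["target"]).sort_values(key=abs, ascending=False)
corr = corr[corr.index != "target"]  # drop the trivial self-correlation of 1.0
print(corr)  # 'a' first; sorting by |r| ranks a strong negative just as high
```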
820
+ def create_distribution_comparison(
821
+ self,
822
+ original_data: pd.DataFrame,
823
+ processed_data: pd.DataFrame,
824
+ columns: Optional[List[str]] = None,
825
+ max_columns: int = 12,
826
+ filename: str = "distribution_comparison"
827
+ ) -> Optional[str]:
828
+ """
829
+ Compare distributions before and after processing
830
+
831
+ Parameters:
832
+ -----------
833
+ original_data : pd.DataFrame
834
+ Original data
835
+ processed_data : pd.DataFrame
836
+ Processed data
837
+ columns : List[str], optional
838
+ List of columns to compare
839
+ max_columns : int
840
+ Maximum number of columns to display
841
+ filename : str
842
+ Filename for saving
843
+
844
+ Returns:
845
+ --------
846
+ str : path to saved file or None if error
847
+ """
848
+ try:
849
+ if columns is None:
850
+ # Select numeric columns common to both datasets
851
+ numeric_cols_original = original_data.select_dtypes(include=[np.number]).columns
852
+ numeric_cols_processed = processed_data.select_dtypes(include=[np.number]).columns
853
+ common_cols = list(set(numeric_cols_original) & set(numeric_cols_processed))
854
+
855
+ # Sort by variance in original data
856
+ variances = original_data[common_cols].var().sort_values(ascending=False)
857
+ columns = variances.head(max_columns).index.tolist()
858
+
859
+ if not columns:
+ logger.warning("No numeric columns available for distribution comparison")
+ return None
+
+ n_cols = min(4, len(columns))
860
+ n_rows = (len(columns) + n_cols - 1) // n_cols
861
+
862
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.5))
863
+ fig.suptitle('Distribution Comparison Before and After Processing',
864
+ fontsize=16, fontweight='bold', y=0.98)
865
+
866
+ # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a
867
+ # flat ndarray so that len() and integer indexing work uniformly below
868
+ axes = np.atleast_1d(axes).ravel()
869
+
870
+ for idx, col in enumerate(columns):
871
+ if idx >= len(axes):
872
+ break
873
+
874
+ ax = axes[idx]
875
+
876
+ if col in original_data.columns and col in processed_data.columns:
877
+ original_values = original_data[col].dropna()
878
+ processed_values = processed_data[col].dropna()
879
+
880
+ if len(original_values) > 0 and len(processed_values) > 0:
881
+ # Use common bins for comparison
882
+ all_values = pd.concat([original_values, processed_values])
883
+ bins = np.histogram_bin_edges(all_values, bins=30)
884
+
885
+ # Histograms
886
+ ax.hist(original_values, bins=bins, alpha=0.5,
887
+ label='Before Processing', density=True, color='blue')
888
+ ax.hist(processed_values, bins=bins, alpha=0.5,
889
+ label='After Processing', density=True, color='orange')
890
+
891
+ # Add KDE
892
+ try:
893
+ if len(original_values) > 10:
894
+ kde_original = gaussian_kde(original_values)
895
+ x_range = np.linspace(original_values.min(), original_values.max(), 100)
896
+ ax.plot(x_range, kde_original(x_range), 'b-', linewidth=1.5, alpha=0.8)
897
+
898
+ if len(processed_values) > 10:
899
+ kde_processed = gaussian_kde(processed_values)
900
+ x_range = np.linspace(processed_values.min(), processed_values.max(), 100)
901
+ ax.plot(x_range, kde_processed(x_range), 'orange', linewidth=1.5, alpha=0.8)
902
+ except Exception:
903
+ pass  # KDE can fail on degenerate (e.g. constant) samples; skip the overlay
904
+
905
+ # Add statistics
906
+ stats_text = []
907
+ if len(original_values) > 0:
908
+ stats_text.append(f"Before: μ={original_values.mean():.2f}, σ={original_values.std():.2f}")
909
+ if len(processed_values) > 0:
910
+ stats_text.append(f"After: μ={processed_values.mean():.2f}, σ={processed_values.std():.2f}")
911
+
912
+ if stats_text:
913
+ ax.text(0.02, 0.98, '\n'.join(stats_text),
914
+ transform=ax.transAxes, fontsize=8,
915
+ verticalalignment='top',
916
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
917
+
918
+ ax.set_title(f'{col}', fontsize=11, fontweight='bold')
919
+ ax.set_xlabel('Value', fontsize=9)
920
+ ax.set_ylabel('Density', fontsize=9)
921
+ ax.legend(fontsize=8)
922
+ ax.grid(True, alpha=0.3)
923
+ else:
924
+ ax.text(0.5, 0.5, 'No data',
925
+ ha='center', va='center', transform=ax.transAxes)
926
+ else:
927
+ ax.text(0.5, 0.5, 'Column not found',
928
+ ha='center', va='center', transform=ax.transAxes)
929
+
930
+ # Hide unused subplots
931
+ for idx in range(len(columns), len(axes)):
932
+ axes[idx].set_visible(False)
933
+
934
+ plt.tight_layout()
935
+
936
+ # Save
937
+ filepath = self._save_figure(fig, filename, "distributions")
938
+ self.plot_files['distribution_comparison'] = filepath
939
+ return filepath
940
+
941
+ except Exception as e:
942
+ logger.error(f"Error creating distribution comparison: {e}")
943
+ return None
944
+
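A minimal sketch of how this comparison might be driven, again assuming an existing `viz` instance; the "processing" here is a plain standardisation, purely for illustration:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
original = pd.DataFrame({
    "flow": rng.normal(100, 20, 500),
    "pressure": rng.normal(5, 1, 500),
})
processed = (original - original.mean()) / original.std()  # stand-in for the real pipeline

# `viz` is an assumed, already-constructed visualizer instance.
path = viz.create_distribution_comparison(original, processed, columns=["flow", "pressure"])
```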
945
+ def create_time_series_decomposition_plot(
946
+ self,
947
+ decomposition_result: Dict,
948
+ filename: str = "time_series_decomposition"
949
+ ) -> Optional[str]:
950
+ """
951
+ Visualise time series decomposition
952
+
953
+ Parameters:
954
+ -----------
955
+ decomposition_result : Dict
956
+ Decomposition results
957
+ filename : str
958
+ Filename for saving
959
+
960
+ Returns:
961
+ --------
962
+ str : path to saved file or None if error
963
+ """
964
+ target_col = self.config.target_column
965
+
966
+ try:
967
+ fig, axes = plt.subplots(4, 1, figsize=(14, 10))
968
+ fig.suptitle(f'Time Series Decomposition: {target_col}',
969
+ fontsize=16, fontweight='bold', y=0.98)
970
+
971
+ # Original series
972
+ if 'observed' in decomposition_result:
973
+ observed = decomposition_result['observed']
974
+ axes[0].plot(observed, color='blue', linewidth=1.5)
975
+ axes[0].set_ylabel('Observed', fontsize=11, fontweight='bold')
976
+ axes[0].grid(True, alpha=0.3)
977
+ axes[0].set_title('Original Time Series', fontsize=12)
978
+
979
+ # Trend
980
+ if 'trend' in decomposition_result and decomposition_result['trend'] is not None:
981
+ trend = decomposition_result['trend']
982
+ axes[1].plot(trend, color='red', linewidth=2)
983
+ axes[1].set_ylabel('Trend', fontsize=11, fontweight='bold')
984
+ axes[1].grid(True, alpha=0.3)
985
+ axes[1].set_title('Trend Component', fontsize=12)
986
+
987
+ # Seasonality
988
+ if 'seasonal' in decomposition_result and decomposition_result['seasonal'] is not None:
989
+ seasonal = decomposition_result['seasonal']
990
+ axes[2].plot(seasonal, color='green', linewidth=1.5)
991
+ axes[2].set_ylabel('Seasonal', fontsize=11, fontweight='bold')
992
+ axes[2].grid(True, alpha=0.3)
993
+ axes[2].set_title('Seasonal Component', fontsize=12)
994
+
995
+ # Residuals
996
+ if 'residual' in decomposition_result and decomposition_result['residual'] is not None:
997
+ residual = decomposition_result['residual']
998
+ axes[3].plot(residual, color='purple', linewidth=1, alpha=0.7)
999
+ axes[3].set_ylabel('Residuals', fontsize=11, fontweight='bold')
1000
+ axes[3].set_xlabel('Date', fontsize=11, fontweight='bold')
1001
+ axes[3].grid(True, alpha=0.3)
1002
+ axes[3].set_title('Residual Component', fontsize=12)
1003
+
1004
+ # Add residual statistics
1005
+ if len(residual) > 0:
1006
+ stats_text = (f"Mean: {residual.mean():.4f}\n"
1007
+ f"Std: {residual.std():.4f}\n"
1008
+ f"Min: {residual.min():.4f}\n"
1009
+ f"Max: {residual.max():.4f}")
1010
+ axes[3].text(0.02, 0.98, stats_text, transform=axes[3].transAxes,
1011
+ fontsize=8, verticalalignment='top',
1012
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
1013
+
1014
+ plt.tight_layout()
1015
+
1016
+ # Save
1017
+ filepath = self._save_figure(fig, filename, "time_series")
1018
+ self.plot_files['time_series_decomposition'] = filepath
1019
+ return filepath
1020
+
1021
+ except Exception as e:
1022
+ logger.error(f"Error creating time series decomposition: {e}")
1023
+ return None
1024
+
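The method only needs a plain dict with 'observed', 'trend', 'seasonal', and 'residual' keys. One way to produce such a dict, an assumption since the project's own decomposition step lives elsewhere, is statsmodels' `seasonal_decompose`:

```python
import numpy as np
import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose

idx = pd.date_range("2024-01-01", periods=120, freq="D")
values = 10 + np.sin(np.arange(120) * 2 * np.pi / 7) + np.random.default_rng(3).normal(0, 0.3, 120)
series = pd.Series(values, index=idx)

result = seasonal_decompose(series, model="additive", period=7)
decomposition_result = {
    "observed": result.observed,
    "trend": result.trend,      # NaN-padded at the edges by the moving average
    "seasonal": result.seasonal,
    "residual": result.resid,
}
# path = viz.create_time_series_decomposition_plot(decomposition_result)  # assumed instance
```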
1025
+ def create_data_quality_report(
1026
+ self,
1027
+ validation_results: Dict,
1028
+ filename: str = "data_quality_report"
1029
+ ) -> Optional[str]:
1030
+ """
1031
+ Create visual data quality report
1032
+
1033
+ Parameters:
1034
+ -----------
1035
+ validation_results : Dict
1036
+ Validation results
1037
+ filename : str
1038
+ Filename for saving
1039
+
1040
+ Returns:
1041
+ --------
1042
+ str : path to saved file or None if error
1043
+ """
1044
+ try:
1045
+ fig = plt.figure(figsize=(16, 12))
1046
+ fig.suptitle('Data Quality Report', fontsize=18, fontweight='bold', y=0.98)
1047
+
1048
+ # Use GridSpec for more complex layout
1049
+ gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
1050
+
1051
+ # 1. Quality radar chart (top left)
1052
+ ax1 = fig.add_subplot(gs[0, 0], projection='polar')
1053
+
1054
+ categories = ['Size', 'Missing', 'Duplicates', 'Stability', 'Informativeness']
1055
+
1056
+ # Extract values from validation results
1057
+ if 'quality_metrics' in validation_results:
1058
+ values = [
1059
+ validation_results['quality_metrics'].get('size_score', 0.5),
1060
+ validation_results['quality_metrics'].get('missing_score', 0.5),
1061
+ validation_results['quality_metrics'].get('duplicates_score', 0.5),
1062
+ validation_results['quality_metrics'].get('stability_score', 0.5),
1063
+ validation_results['quality_metrics'].get('informativeness_score', 0.5)
1064
+ ]
1065
+ else:
1066
+ values = [0.8, 0.7, 0.9, 0.6, 0.8]  # neutral placeholders when quality_metrics is absent
1067
+
1068
+ N = len(categories)
1069
+ angles = [n / float(N) * 2 * np.pi for n in range(N)]
1070
+ angles += angles[:1]
1071
+ values += values[:1]
1072
+
1073
+ ax1.plot(angles, values, 'o-', linewidth=2, color='blue')
1074
+ ax1.fill(angles, values, alpha=0.25, color='blue')
1075
+ ax1.set_xticks(angles[:-1])
1076
+ ax1.set_xticklabels(categories, fontsize=10)
1077
+ ax1.set_ylim(0, 1)
1078
+ ax1.set_title('Data Quality Radar Chart', fontsize=12, fontweight='bold')
1079
+ ax1.grid(True)
1080
+
1081
+ # 2. Check status (top centre)
1082
+ ax2 = fig.add_subplot(gs[0, 1])
1083
+
1084
+ basic_checks = validation_results.get('basic_checks', {})
1085
+ checks_passed = sum(1 for check in basic_checks.values() if check.get('passed', False))
1086
+ checks_total = len(basic_checks)
1087
+ checks_failed = checks_total - checks_passed
1088
+
1089
+ if checks_total > 0:
1090
+ # Zero-height bars are invisible, so fixed colours are sufficient here
1091
+ colors = ['#4CAF50', '#FF6B6B']  # green for passed, red for failed
1092
+ bars = ax2.bar(['Passed', 'Failed'],
1093
+ [checks_passed, checks_failed],
1094
+ color=colors, edgecolor='black')
1095
+
1096
+ ax2.set_title(f'Basic Checks: {checks_passed}/{checks_total}',
1097
+ fontsize=12, fontweight='bold')
1098
+ ax2.set_ylabel('Number of Checks', fontsize=10)
1099
+ ax2.grid(True, alpha=0.3, axis='y')
1100
+
1101
+ # Add values on bars
1102
+ for bar, value in zip(bars, [checks_passed, checks_failed]):
1103
+ height = bar.get_height()
1104
+ ax2.text(bar.get_x() + bar.get_width()/2., height,
1105
+ f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
1106
+ else:
1107
+ ax2.text(0.5, 0.5, 'No check data available',
1108
+ ha='center', va='center', transform=ax2.transAxes)
1109
+ ax2.set_title('Basic Checks', fontsize=12, fontweight='bold')
1110
+
1111
+ # 3. Overall score (top right)
1112
+ ax3 = fig.add_subplot(gs[0, 2])
1113
+
1114
+ overall_score = validation_results.get('overall_score', 0)
1115
+ status = validation_results.get('status', 'UNKNOWN')
1116
+
1117
+ # Score pie chart
1118
+ sizes = [overall_score, 100 - overall_score]
1119
+
1120
+ if overall_score >= 80:
1121
+ colors = ['#4CAF50', '#E0E0E0'] # Green
1122
+ elif overall_score >= 60:
1123
+ colors = ['#FFC107', '#E0E0E0'] # Yellow
1124
+ else:
1125
+ colors = ['#F44336', '#E0E0E0'] # Red
1126
+
1127
+ wedges, texts, autotexts = ax3.pie(sizes, colors=colors, startangle=90,
1128
+ autopct='%1.1f%%', pctdistance=0.85)
1129
+
1130
+ # Central text
1131
+ status_colors = {'PASS': '#4CAF50', 'WARNING': '#FFC107', 'FAIL': '#F44336'}
1132
+ status_color = status_colors.get(status, '#757575')
1133
+
1134
+ ax3.text(0, 0, f'{overall_score}/100\n{status}',
1135
+ ha='center', va='center', fontsize=14, fontweight='bold',
1136
+ color=status_color)
1137
+ ax3.set_title('Overall Quality Score', fontsize=12, fontweight='bold')
1138
+
1139
+ # 4. Issue distribution by type (left middle)
1140
+ ax4 = fig.add_subplot(gs[1, 0])
1141
+
1142
+ issues = validation_results.get('issues', {})
1143
+ issue_counts = {
1144
+ 'Critical': len(issues.get('critical', [])),
1145
+ 'Warnings': len(issues.get('warning', [])),
1146
+ 'Informational': len(issues.get('info', []))
1147
+ }
1148
+
1149
+ if any(issue_counts.values()):
1150
+ colors = ['#F44336', '#FF9800', '#2196F3']
1151
+ bars = ax4.bar(issue_counts.keys(), issue_counts.values(),
1152
+ color=colors, edgecolor='black')
1153
+
1154
+ ax4.set_title('Data Issues by Type', fontsize=12, fontweight='bold')
1155
+ ax4.set_ylabel('Number of Issues', fontsize=10)
1156
+ ax4.tick_params(axis='x', rotation=45)
1157
+ ax4.grid(True, alpha=0.3, axis='y')
1158
+
1159
+ # Add values on bars
1160
+ for bar, value in zip(bars, issue_counts.values()):
1161
+ height = bar.get_height()
1162
+ ax4.text(bar.get_x() + bar.get_width()/2., height,
1163
+ f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
1164
+ else:
1165
+ ax4.text(0.5, 0.5, 'No issues detected',
1166
+ ha='center', va='center', transform=ax4.transAxes, fontsize=12)
1167
+ ax4.set_title('Data Issues', fontsize=12, fontweight='bold')
1168
+
1169
+ # 5. Detailed information (remaining cells)
1170
+ ax5 = fig.add_subplot(gs[1:, 1:])
1171
+ ax5.axis('off')
1172
+
1173
+ # Form text report
1174
+ report_text = []
1175
+ report_text.append("DETAILED REPORT:")
1176
+ report_text.append("=" * 40)
1177
+
1178
+ # Basic information
1179
+ report_text.append("\nBASIC INFORMATION:")
1180
+ report_text.append(f"• Overall score: {overall_score}/100")
1181
+ report_text.append(f"• Status: {status}")
1182
+ report_text.append(f"• Checks passed: {checks_passed}/{checks_total}")
1183
+
1184
+ # Check details
1185
+ if basic_checks:
1186
+ report_text.append("\nCHECK DETAILS:")
1187
+ for check_name, check_result in basic_checks.items():
1188
+ status_icon = "✓" if check_result.get('passed', False) else "✗"
1189
+ report_text.append(f"• {status_icon} {check_name}: {check_result.get('message', '')}")
1190
+
1191
+ # Issues
1192
+ if any(issue_counts.values()):
1193
+ report_text.append("\nDETECTED ISSUES:")
1194
+
1195
+ if issue_counts['Critical'] > 0:
1196
+ report_text.append("\nCRITICAL:")
1197
+ for issue in issues.get('critical', []):
1198
+ report_text.append(f" • {issue}")
1199
+
1200
+ if issue_counts['Warnings'] > 0:
1201
+ report_text.append("\nWARNINGS:")
1202
+ for issue in issues.get('warning', []):
1203
+ report_text.append(f" • {issue}")
1204
+
1205
+ if issue_counts['Informational'] > 0:
1206
+ report_text.append("\nINFORMATIONAL:")
1207
+ for issue in issues.get('info', []):
1208
+ report_text.append(f" • {issue}")
1209
+
1210
+ # Recommendations
1211
+ recommendations = validation_results.get('recommendations', [])
1212
+ if recommendations:
1213
+ report_text.append("\nRECOMMENDATIONS:")
1214
+ for i, rec in enumerate(recommendations, 1):
1215
+ report_text.append(f"{i}. {rec}")
1216
+
1217
+ ax5.text(0.02, 0.98, '\n'.join(report_text), transform=ax5.transAxes,
1218
+ fontsize=9, verticalalignment='top', fontfamily='monospace',
1219
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.1))
1220
+
1221
+ plt.tight_layout()
1222
+
1223
+ # Save
1224
+ filepath = self._save_figure(fig, filename, "reports")
1225
+ self.plot_files['data_quality_report'] = filepath
1226
+ return filepath
1227
+
1228
+ except Exception as e:
1229
+ logger.error(f"Error creating data quality report: {e}")
1230
+ return None
1231
+
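The report pulls loosely-coupled keys out of `validation_results` with `.get()`, so the expected shape is only implicit. A minimal dict that exercises every panel (key names read off the code above; the real producer sits elsewhere in the pipeline):

```python
validation_results = {
    "overall_score": 72,
    "status": "WARNING",
    "quality_metrics": {
        "size_score": 0.9, "missing_score": 0.6, "duplicates_score": 1.0,
        "stability_score": 0.5, "informativeness_score": 0.7,
    },
    "basic_checks": {
        "has_rows": {"passed": True, "message": "500 rows found"},
        "target_present": {"passed": False, "message": "target column missing"},
    },
    "issues": {
        "critical": [],
        "warning": ["3% missing values in 'ph'"],
        "info": ["index is not monotonic"],
    },
    "recommendations": ["Impute missing 'ph' values before training."],
}
# path = viz.create_data_quality_report(validation_results)  # assumed instance
```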
1232
+ # ============================================
1233
+ # METHODS FOR BATCH SAVING
1234
+ # ============================================
1235
+
1236
+ def save_all_preprocessing_plots(self) -> Dict[str, str]:
1237
+ """
1238
+ Save all preprocessing plots from current session
1239
+
1240
+ Returns:
1241
+ --------
1242
+ Dict[str, str] : dictionary with paths to saved plots
1243
+ """
1244
+ logger.info("Saving all preprocessing plots...")
1245
+
1246
+ plots_saved = {}
1247
+
1248
+ # Get all open figures
1249
+ figure_numbers = plt.get_fignums()
1250
+
1251
+ if not figure_numbers:
1252
+ logger.warning("No open plots to save")
1253
+ return plots_saved
1254
+
1255
+ # Save each plot
1256
+ for fig_num in figure_numbers:
1257
+ fig = plt.figure(fig_num)
1258
+ filename = f"preprocessing_plot_{fig_num}.png"
1259
+ filepath = self._save_figure(fig, filename, "preprocessing")
1260
+ if filepath:
1261
+ plots_saved[f"plot_{fig_num}"] = filepath
1262
+
1263
+ logger.info(f"Saved {len(plots_saved)} preprocessing plots")
1264
+ return plots_saved
1265
+
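Because this sweeps `plt.get_fignums()`, it only captures figures that are still open; anything already released with `plt.close()` is not recoverable. A short sketch, with `viz` again an assumed instance:

```python
import matplotlib.pyplot as plt

plt.figure(); plt.plot([1, 2, 3])      # figure 1
plt.figure(); plt.hist([1, 1, 2, 3])   # figure 2

saved = viz.save_all_preprocessing_plots()
print(saved)  # e.g. {'plot_1': '.../preprocessing/...', 'plot_2': '.../preprocessing/...'}
```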
1266
+ def create_all_visualizations(
1267
+ self,
1268
+ data: pd.DataFrame,
1269
+ processed_data: Optional[pd.DataFrame] = None,
1270
+ feature_importance: Optional[Dict] = None,
1271
+ decomposition_result: Optional[Dict] = None,
1272
+ validation_results: Optional[Dict] = None,
1273
+ preprocessing_stages: Optional[Dict] = None
1274
+ ) -> Dict[str, str]:
1275
+ """
1276
+ Create all visualisations in one call
1277
+
1278
+ Parameters:
1279
+ -----------
1280
+ data : pd.DataFrame
1281
+ Original data
1282
+ processed_data : pd.DataFrame, optional
1283
+ Processed data
1284
+ feature_importance : Dict, optional
1285
+ Feature importance
1286
+ decomposition_result : Dict, optional
1287
+ Decomposition results
1288
+ validation_results : Dict, optional
1289
+ Validation results
1290
+ preprocessing_stages : Dict, optional
1291
+ Preprocessing stages
1292
+
1293
+ Returns:
1294
+ --------
1295
+ Dict[str, str] : dictionary with paths to created plots
1296
+ """
1297
+ logger.info("\n" + "="*80)
1298
+ logger.info("STARTING ALL VISUALISATIONS CREATION")
1299
+ logger.info("="*80)
1300
+
1301
+ result_files = {}
1302
+
1303
+ # 1. Summary dashboard
1304
+ if data is not None:
1305
+ logger.info("Creating summary dashboard...")
1306
+ summary_path = self.create_summary_dashboard(data, preprocessing_stages)
1307
+ if summary_path:
1308
+ result_files['summary'] = summary_path
1309
+
1310
+ # 2. Correlation heatmaps
1311
+ if data is not None:
1312
+ logger.info("Creating correlation heatmaps...")
1313
+ main_corr, target_corr = self.create_correlation_heatmap(data)
1314
+ if main_corr:
1315
+ result_files['correlation_main'] = main_corr
1316
+ if target_corr:
1317
+ result_files['correlation_target'] = target_corr
1318
+
1319
+ # 3. Distribution comparison
1320
+ if data is not None and processed_data is not None:
1321
+ logger.info("Creating distribution comparison...")
1322
+ dist_path = self.create_distribution_comparison(data, processed_data)
1323
+ if dist_path:
1324
+ result_files['distribution'] = dist_path
1325
+
1326
+ # 4. Feature importance
1327
+ if feature_importance:
1328
+ logger.info("Creating feature importance plot...")
1329
+ feat_path = self.create_feature_importance_plot(feature_importance)
1330
+ if feat_path:
1331
+ result_files['feature_importance'] = feat_path
1332
+
1333
+ # 5. Time series decomposition
1334
+ if decomposition_result:
1335
+ logger.info("Creating time series decomposition...")
1336
+ decomp_path = self.create_time_series_decomposition_plot(decomposition_result)
1337
+ if decomp_path:
1338
+ result_files['decomposition'] = decomp_path
1339
+
1340
+ # 6. Data quality report
1341
+ if validation_results:
1342
+ logger.info("Creating data quality report...")
1343
+ quality_path = self.create_data_quality_report(validation_results)
1344
+ if quality_path:
1345
+ result_files['quality_report'] = quality_path
1346
+
1347
+ # Save information about all plots
1348
+ self.save_plots_info()
1349
+
1350
+ logger.info("\n" + "="*80)
1351
+ logger.info("VISUALISATIONS SUCCESSFULLY CREATED")
1352
+ logger.info("="*80)
1353
+
1354
+ for plot_name, plot_path in result_files.items():
1355
+ if plot_path:
1356
+ logger.info(f"✓ {plot_name}: {plot_path}")
1357
+
1358
+ return result_files
1359
+
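An end-to-end sketch tying the pieces together, reusing `original`, `processed`, and `validation_results` from the sketches above. Only `data` is required; each optional argument simply enables its corresponding plot:

```python
files = viz.create_all_visualizations(
    data=original,                          # raw frame
    processed_data=processed,               # enables the distribution comparison
    validation_results=validation_results,  # enables the quality report
)
for name, path in files.items():
    print(f"{name}: {path}")
```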
1360
+ def get_all_plots(self) -> Dict:
1361
+ """Get information about all created plots"""
1362
+ return self.plot_files
1363
+
1364
+ def save_plots_info(self, filename: str = "plots_info.json") -> None:
1365
+ """Save plot information to JSON file"""
1366
+ try:
1367
+ plots_info = {
1368
+ 'total_plots': len(self.plot_files),
1369
+ 'plots': self.plot_files,
1370
+ 'directories': {
1371
+ 'correlations': self.correlations_dir,
1372
+ 'distributions': self.distributions_dir,
1373
+ 'features': self.features_dir,
1374
+ 'time_series': self.time_series_dir,
1375
+ 'preprocessing': self.preprocessing_dir,
1376
+ 'summary': self.summary_dir,
1377
+ 'reports': self.reports_dir
1378
+ },
1379
+ 'generation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
1380
+ 'config': {
1381
+ 'target_column': self.config.target_column,
1382
+ 'results_dir': self.config.results_dir
1383
+ }
1384
+ }
1385
+
1386
+ filepath = os.path.join(self.reports_dir, filename)
1387
+
1388
+ with open(filepath, 'w', encoding='utf-8') as f:
1389
+ json.dump(plots_info, f, indent=4, ensure_ascii=False, default=str)
1390
+
1391
+ logger.info(f"✓ Plot information saved: {filepath}")
1392
+
1393
+ except Exception as e:
1394
+ logger.error(f"✗ Error saving plot information: {e}")
1395
+
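The manifest is ordinary JSON, so downstream tooling (or a Streamlit page) can recover every plot path without touching matplotlib. Reading it back, under the same `viz` assumption:

```python
import json
import os

with open(os.path.join(viz.reports_dir, "plots_info.json"), encoding="utf-8") as f:
    manifest = json.load(f)

print(manifest["total_plots"], "plots generated at", manifest["generation_time"])
```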
1396
+ def move_existing_plots(self, source_dir: Optional[str] = None) -> Dict[str, str]:
1397
+ """
1398
+ Move existing plots from specified directory to structured folders
1399
+
1400
+ Parameters:
1401
+ -----------
1402
+ source_dir : str, optional
1403
+ Directory with existing plots
1404
+
1405
+ Returns:
1406
+ --------
1407
+ Dict[str, str] : dictionary with information about moved files
1408
+ """
1409
+ if source_dir is None:
1410
+ source_dir = self.plots_dir
1411
+
1412
+ if not os.path.exists(source_dir):
1413
+ logger.warning(f"Source directory doesn't exist: {source_dir}")
1414
+ return {}
1415
+
1416
+ # File to folder mapping
1417
+ file_to_folder_map = {
1418
+ # Time series
1419
+ 'data_split.png': 'time_series',
1420
+ 'stationarity_raskhodvoda.png': 'time_series',
1421
+ 'stationarity_analysis.png': 'time_series',
1422
+ 'temporal_outliers.png': 'time_series',
1423
+
1424
+ # Correlations
1425
+ 'feature_selection_correlation.png': 'correlations',
1426
+
1427
+ # Preprocessing
1428
+ 'missing_values_analysis.png': 'preprocessing',
1429
+ 'outlier_handling_results.png': 'preprocessing',
1430
+ 'outliers_analysis.png': 'preprocessing',
1431
+ 'scaling_results.png': 'preprocessing',
1432
+
1433
+ # Default
1434
+ 'default': 'summary'
1435
+ }
1436
+
1437
+ moved_files = {}
1438
+
1439
+ for filename in os.listdir(source_dir):
1440
+ if filename.endswith('.png'):
1441
+ source_path = os.path.join(source_dir, filename)
1442
+
1443
+ # Determine destination folder
1444
+ target_folder = file_to_folder_map.get(filename, file_to_folder_map['default'])
1445
+ target_dir = os.path.join(self.plots_dir, target_folder)
1446
+
1447
+ # Create destination folder if doesn't exist
1448
+ os.makedirs(target_dir, exist_ok=True)
1449
+
1450
+ # Target path
1451
+ target_path = os.path.join(target_dir, filename)
1452
+
1453
+ try:
1454
+ # Move file (os.rename assumes source and target share a filesystem)
1455
+ os.rename(source_path, target_path)
1456
+ moved_files[filename] = target_path
1457
+ logger.info(f"Moved: {filename} -> {target_folder}/")
1458
+ except Exception as e:
1459
+ logger.error(f"Error moving {filename}: {e}")
1460
+
1461
+ logger.info(f"Moved {len(moved_files)} files")
1462
+ return moved_files
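Usage is a one-liner; note that `os.rename` only works within a single filesystem, so `shutil.move` would be the drop-in replacement if the results directory ever sits on another mount. The explicit-directory form below uses a hypothetical path:

```python
moved = viz.move_existing_plots()                     # defaults to viz.plots_dir
moved = viz.move_existing_plots("old_results/plots")  # hypothetical source directory
print(f"Reorganised {len(moved)} files")
```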