Spaces:
Runtime error
Runtime error
Update all files
Browse files- .gitignore +9 -0
- Dockerfile +1 -1
- README.md +262 -14
- app.py +0 -0
- config/__init__.py +0 -0
- config/config.py +169 -0
- config/default_config.json +78 -0
- config/settings.py +375 -0
- correlations/__init__.py +0 -0
- correlations/correlation_analyzer.py +687 -0
- data_loader/__init__.py +0 -0
- data_loader/data_loader.py +487 -0
- decomposition/__init__.py +0 -0
- decomposition/decomposer.py +690 -0
- feature_selection/__init__.py +0 -0
- feature_selection/feature_selector.py +478 -0
- features/__init__.py +0 -0
- features/feature_engineer.py +638 -0
- missing_values/__init__.py +0 -0
- missing_values/missing_analyzer.py +700 -0
- outliers/__init__.py +0 -0
- outliers/outlier_analyzer.py +857 -0
- pipeline/__init__.py +0 -0
- pipeline/main_pipeline.py +603 -0
- requirements.txt +100 -3
- run_pipeline.py +62 -0
- scaling/__init__.py +0 -0
- scaling/data_scaler.py +634 -0
- splitting/__init__.py +0 -0
- splitting/data_splitter.py +403 -0
- src/streamlit_app.py +0 -40
- stationarity/__init__.py +0 -0
- stationarity/stationarity_checker.py +631 -0
- streamlit/streamlit_app.py +0 -0
- temp_data.csv +0 -0
- validation/__init__.py +0 -0
- validation/data_validator.py +655 -0
- visualization/__init__.py +0 -0
- visualization/visualization_manager.py +1462 -0
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
.venv
|
| 3 |
+
__pycache__/
|
| 4 |
+
*__pycache__/
|
| 5 |
+
*.pyc
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
.Python
|
| 9 |
+
streamlit_results/
|
Dockerfile
CHANGED
|
@@ -17,4 +17,4 @@ EXPOSE 8501
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "
|
|
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "streamlit/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
CHANGED
|
@@ -1,20 +1,268 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
pinned: false
|
| 11 |
-
short_description: TimeFlowPro
|
| 12 |
-
license: mit
|
| 13 |
---
|
| 14 |
|
| 15 |
-
#
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: TimeFlow Pro
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
pinned: true
|
| 8 |
+
app_file: app.py
|
| 9 |
+
sdk_version: 1.52.2
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# 📊 TimeFlow Pro
|
| 13 |
|
| 14 |
+
<div align="center">
|
| 15 |
|
| 16 |
+
**Intelligent Time Series Data Analysis and Preprocessing Platform**
|
| 17 |
+
|
| 18 |
+
*Advanced pipeline for data preparation and feature engineering*
|
| 19 |
+
|
| 20 |
+
[](https://huggingface.co/spaces/your-username/timeflow-pro)
|
| 21 |
+
[](https://streamlit.io)
|
| 22 |
+
[](https://python.org)
|
| 23 |
+
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
## 🌟 Overview
|
| 27 |
+
|
| 28 |
+
TimeFlow Pro is a comprehensive platform for time series data analysis, preprocessing, and feature engineering. Designed for data scientists and analysts, it provides an intuitive interface for transforming raw time series data into ML-ready datasets with advanced preprocessing capabilities.
|
| 29 |
+
|
| 30 |
+
## 🚀 Key Features
|
| 31 |
+
|
| 32 |
+
### 📈 **Data Analysis & Visualization**
|
| 33 |
+
- **Interactive Data Exploration**: Real-time preview and statistics
|
| 34 |
+
- **Missing Value Analysis**: Smart detection and handling strategies
|
| 35 |
+
- **Outlier Detection**: Multiple methods including IQR, Z-Score, Isolation Forest
|
| 36 |
+
- **Temporal Analysis**: Seasonality detection, trend analysis, decomposition
|
| 37 |
+
|
| 38 |
+
### ⚙️ **Advanced Preprocessing Pipeline**
|
| 39 |
+
- **Feature Engineering**: Automatic lag features, rolling statistics, seasonal components
|
| 40 |
+
- **Stationarity Checking**: ADF tests and transformation suggestions
|
| 41 |
+
- **Data Scaling**: Robust, Standard, MinMax, and custom scaling methods
|
| 42 |
+
- **Feature Selection**: Correlation, variance, mutual information, RF importance
|
| 43 |
+
|
| 44 |
+
### 🏗️ **ML-Ready Outputs**
|
| 45 |
+
- **Train/Validation/Test Splits**: Time-based or random splitting
|
| 46 |
+
- **Multiple Export Formats**: CSV, Parquet, Excel, JSON
|
| 47 |
+
- **Model Integration**: Ready-to-use datasets for scikit-learn, XGBoost, LightGBM
|
| 48 |
+
- **Visual Reports**: Comprehensive pipeline execution reports
|
| 49 |
+
|
| 50 |
+
## 🎮 Quick Start
|
| 51 |
+
|
| 52 |
+
### 1. **Upload Your Data**
|
| 53 |
+
- Support for CSV, Excel, Parquet formats
|
| 54 |
+
- Automatic date parsing and validation
|
| 55 |
+
- Smart column type detection
|
| 56 |
+
|
| 57 |
+
### 2. **Configure Pipeline**
|
| 58 |
+
```python
|
| 59 |
+
# Example configuration
|
| 60 |
+
config = {
|
| 61 |
+
'target_column': 'sales',
|
| 62 |
+
'test_size': 0.2,
|
| 63 |
+
'max_lags': 5,
|
| 64 |
+
'seasonal_period': 365,
|
| 65 |
+
'scaling_method': 'robust'
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### 3. **Run Pipeline & Export**
|
| 70 |
+
- Execute full preprocessing pipeline
|
| 71 |
+
- Download processed data
|
| 72 |
+
- Get feature importance reports
|
| 73 |
+
- Export modeling datasets
|
| 74 |
+
|
| 75 |
+
## 📊 Technical Architecture
|
| 76 |
+
|
| 77 |
+
### 🔧 **Pipeline Components**
|
| 78 |
+
```
|
| 79 |
+
Data Loading → Validation → Missing Handling → Outlier Treatment
|
| 80 |
+
↓
|
| 81 |
+
Feature Engineering → Stationarity Check → Correlation Analysis
|
| 82 |
+
↓
|
| 83 |
+
Data Splitting → Scaling → Feature Selection → Final Validation
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### 🏆 **Core Features**
|
| 87 |
+
- **Multi-stage Validation**: Raw, processed, and final data validation
|
| 88 |
+
- **Memory Optimization**: Efficient handling of large datasets
|
| 89 |
+
- **Error Recovery**: Graceful handling of pipeline failures
|
| 90 |
+
- **Reproducible Results**: Configuration saving and logging
|
| 91 |
+
|
| 92 |
+
## 📚 Use Cases
|
| 93 |
+
|
| 94 |
+
### 🏢 **Business Analytics**
|
| 95 |
+
- Sales forecasting and trend analysis
|
| 96 |
+
- Inventory optimization
|
| 97 |
+
- Customer behavior prediction
|
| 98 |
+
- Financial time series analysis
|
| 99 |
+
|
| 100 |
+
### 🏭 **Industrial Applications**
|
| 101 |
+
- Sensor data preprocessing
|
| 102 |
+
- Predictive maintenance
|
| 103 |
+
- Quality control monitoring
|
| 104 |
+
- Energy consumption forecasting
|
| 105 |
+
|
| 106 |
+
### 🎓 **Academic Research**
|
| 107 |
+
- Time series modeling experiments
|
| 108 |
+
- Feature engineering research
|
| 109 |
+
- Algorithm comparison studies
|
| 110 |
+
- Educational tool for data science
|
| 111 |
+
|
| 112 |
+
## 🛠️ Installation
|
| 113 |
+
|
| 114 |
+
### Local Development
|
| 115 |
+
```bash
|
| 116 |
+
# Clone repository
|
| 117 |
+
git clone https://huggingface.co/spaces/your-username/timeflow-pro
|
| 118 |
+
cd timeflow-pro
|
| 119 |
+
|
| 120 |
+
# Install dependencies
|
| 121 |
+
pip install -r requirements.txt
|
| 122 |
+
|
| 123 |
+
# Run application
|
| 124 |
+
streamlit run app.py
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Docker Deployment
|
| 128 |
+
```bash
|
| 129 |
+
# Build Docker image
|
| 130 |
+
docker build -t timeflow-pro .
|
| 131 |
+
|
| 132 |
+
# Run container
|
| 133 |
+
docker run -p 8501:8501 timeflow-pro
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## 🌐 API Usage Example
|
| 137 |
+
|
| 138 |
+
```python
|
| 139 |
+
from timeflow_pro import TimeFlowPipeline
|
| 140 |
+
import pandas as pd
|
| 141 |
+
|
| 142 |
+
# Load your data
|
| 143 |
+
data = pd.read_csv('your_data.csv')
|
| 144 |
+
|
| 145 |
+
# Configure pipeline
|
| 146 |
+
config = {
|
| 147 |
+
'target_column': 'target',
|
| 148 |
+
'test_size': 0.2,
|
| 149 |
+
'max_lags': 7,
|
| 150 |
+
'seasonal_period': 30
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
# Create and run pipeline
|
| 154 |
+
pipeline = TimeFlowPipeline(config)
|
| 155 |
+
processed_data = pipeline.run(data)
|
| 156 |
+
|
| 157 |
+
# Get modeling data
|
| 158 |
+
modeling_data = pipeline.get_modeling_data()
|
| 159 |
+
X_train, y_train = modeling_data['X_train'], modeling_data['y_train']
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## 📈 Performance Benchmarks
|
| 163 |
+
|
| 164 |
+
| Dataset Size | Processing Time | Memory Usage | Features Generated |
|
| 165 |
+
|--------------|----------------|--------------|-------------------|
|
| 166 |
+
| 10K rows | ~5 seconds | <500 MB | 50-100 features |
|
| 167 |
+
| 100K rows | ~30 seconds | <1 GB | 100-200 features |
|
| 168 |
+
| 1M rows | ~5 minutes | <2 GB | 200-500 features |
|
| 169 |
+
|
| 170 |
+
## 🔧 Configuration Options
|
| 171 |
+
|
| 172 |
+
### **Data Processing**
|
| 173 |
+
- `missing_threshold`: Threshold for column removal (0.0-0.5)
|
| 174 |
+
- `outlier_method`: IQR, Z-Score, or Isolation Forest
|
| 175 |
+
- `scaling_method`: Robust, Standard, MinMax, or None
|
| 176 |
+
|
| 177 |
+
### **Feature Engineering**
|
| 178 |
+
- `max_lags`: Maximum lag features (1-20)
|
| 179 |
+
- `seasonal_period`: Seasonal window (7, 30, 90, 365)
|
| 180 |
+
- `rolling_windows`: List of rolling windows [7, 30, 90]
|
| 181 |
+
|
| 182 |
+
### **Model Preparation**
|
| 183 |
+
- `feature_selection_method`: Correlation, Variance, RF, Mutual Info
|
| 184 |
+
- `max_features`: Maximum features to select (5-100)
|
| 185 |
+
- `split_method`: Time-based or random splitting
|
| 186 |
+
|
| 187 |
+
## 📋 Requirements
|
| 188 |
+
|
| 189 |
+
### **Core Dependencies**
|
| 190 |
+
```txt
|
| 191 |
+
streamlit>=1.28.0
|
| 192 |
+
pandas>=2.0.0
|
| 193 |
+
numpy>=1.24.0
|
| 194 |
+
plotly>=5.17.0
|
| 195 |
+
scikit-learn>=1.3.0
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
### **Optional Dependencies**
|
| 199 |
+
```txt
|
| 200 |
+
xgboost>=2.0.0 # For XGBoost feature importance
|
| 201 |
+
lightgbm>=4.0.0 # For LightGBM integration
|
| 202 |
+
statsmodels>=0.14.0 # For advanced time series analysis
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## 🤝 Contributing
|
| 206 |
+
|
| 207 |
+
We welcome contributions! Here's how you can help:
|
| 208 |
+
|
| 209 |
+
### **Areas for Contribution**
|
| 210 |
+
1. **New Feature Engineering Methods**
|
| 211 |
+
2. **Additional Visualization Types**
|
| 212 |
+
3. **Export Format Support**
|
| 213 |
+
4. **Performance Optimizations**
|
| 214 |
+
5. **Documentation Improvements**
|
| 215 |
+
|
| 216 |
+
### **Development Workflow**
|
| 217 |
+
```bash
|
| 218 |
+
# 1. Fork the repository
|
| 219 |
+
# 2. Create feature branch
|
| 220 |
+
git checkout -b feature/new-feature
|
| 221 |
+
|
| 222 |
+
# 3. Make changes and test
|
| 223 |
+
# 4. Submit pull request
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
## 📜 License
|
| 227 |
+
|
| 228 |
+
This project is licensed under the **MIT License** - see the [LICENSE](LICENSE) file for details.
|
| 229 |
+
|
| 230 |
+
## 🙏 Acknowledgments
|
| 231 |
+
|
| 232 |
+
### **Special Thanks To:**
|
| 233 |
+
- **Streamlit Team** for the amazing framework
|
| 234 |
+
- **Hugging Face** for hosting the Space
|
| 235 |
+
- **Open Source Community** for invaluable libraries
|
| 236 |
+
- **All Contributors** who helped improve TimeFlow Pro
|
| 237 |
+
|
| 238 |
+
### **Built With:**
|
| 239 |
+
- 🐍 Python
|
| 240 |
+
- 📊 Streamlit
|
| 241 |
+
- 🎨 Plotly
|
| 242 |
+
- 🔧 Scikit-learn
|
| 243 |
+
- 📈 Pandas & NumPy
|
| 244 |
+
|
| 245 |
+
## 📞 Support & Contact
|
| 246 |
+
|
| 247 |
+
### **Get Help:**
|
| 248 |
+
- 📧 **Email**: cool.araby@gmail.com
|
| 249 |
+
- 💬 **Issues**: [GitHub Issues](https://github.com/your-username/timeflow-pro/issues)
|
| 250 |
+
- 💡 **Discussions**: [Community Forum](https://github.com/your-username/timeflow-pro/discussions)
|
| 251 |
+
|
| 252 |
+
### **Stay Updated:**
|
| 253 |
+
- ⭐ **Star** the repository
|
| 254 |
+
- 👁️ **Watch** for releases
|
| 255 |
+
- 🔔 **Enable notifications**
|
| 256 |
+
|
| 257 |
+
---
|
| 258 |
+
|
| 259 |
+
<div align="center">
|
| 260 |
+
|
| 261 |
+
**Transform Your Time Series Data with Ease**
|
| 262 |
+
|
| 263 |
+
*TimeFlow Pro - Making Data Preparation Simple and Powerful*
|
| 264 |
+
|
| 265 |
+
[](https://huggingface.co/your-username)
|
| 266 |
+
[](https://github.com/your-username/timeflow-pro)
|
| 267 |
+
|
| 268 |
+
</div>
|
app.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
config/__init__.py
ADDED
|
File without changes
|
config/config.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# ENUMERATION CLASSES
|
| 3 |
+
# ============================================
|
| 4 |
+
from dataclasses import asdict, dataclass, field
|
| 5 |
+
from enum import Enum
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Optional
|
| 10 |
+
from venv import logger
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DataType(Enum):
|
| 14 |
+
"""Data types"""
|
| 15 |
+
NUMERIC = "numeric"
|
| 16 |
+
CATEGORICAL = "categorical"
|
| 17 |
+
TEMPORAL = "temporal"
|
| 18 |
+
TEXT = "text"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PreprocessingMethod(Enum):
|
| 22 |
+
"""Data preprocessing methods"""
|
| 23 |
+
FILL_MEAN = "fill_mean"
|
| 24 |
+
FILL_MEDIAN = "fill_median"
|
| 25 |
+
FILL_INTERPOLATE = "fill_interpolate"
|
| 26 |
+
FILL_KNN = "fill_knn"
|
| 27 |
+
REMOVE = "remove"
|
| 28 |
+
CLIP = "clip"
|
| 29 |
+
WINSORIZE = "winsorize"
|
| 30 |
+
NORMALIZE = "normalize"
|
| 31 |
+
STANDARDIZE = "standardize"
|
| 32 |
+
LOG_TRANSFORM = "log_transform"
|
| 33 |
+
BOX_COX = "box_cox"
|
| 34 |
+
DIFFERENCING = "differencing"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SeasonalityType(Enum):
|
| 38 |
+
"""Seasonality types"""
|
| 39 |
+
DAILY = "daily"
|
| 40 |
+
WEEKLY = "weekly"
|
| 41 |
+
MONTHLY = "monthly"
|
| 42 |
+
QUARTERLY = "quarterly"
|
| 43 |
+
YEARLY = "yearly"
|
| 44 |
+
MULTIPLE = "multiple"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ============================================
|
| 48 |
+
# CLASS 1: CONFIGURATION
|
| 49 |
+
# ============================================
|
| 50 |
+
@dataclass
|
| 51 |
+
class Config:
|
| 52 |
+
"""Experiment configuration for data preprocessing"""
|
| 53 |
+
|
| 54 |
+
# Paths and directories
|
| 55 |
+
data_path: str = 'temp_data.csv'
|
| 56 |
+
results_dir: str = 'data_preprocessing_results'
|
| 57 |
+
|
| 58 |
+
# Temporal parameters
|
| 59 |
+
start_year: int = 1970
|
| 60 |
+
end_year: int = 1990
|
| 61 |
+
freq: str = 'D' # Data frequency: D (daily), H (hourly), M (monthly)
|
| 62 |
+
|
| 63 |
+
# Target variable
|
| 64 |
+
target_column: str = 'raskhodvoda'
|
| 65 |
+
|
| 66 |
+
# Feature parameters
|
| 67 |
+
max_lags: int = 12
|
| 68 |
+
seasonal_period: int = 365
|
| 69 |
+
rolling_windows: List[int] = field(default_factory=lambda: [7, 30, 90, 365])
|
| 70 |
+
expanding_windows: List[int] = field(default_factory=lambda: [30, 90, 365])
|
| 71 |
+
|
| 72 |
+
# Processing parameters
|
| 73 |
+
missing_threshold: float = 0.3 # Threshold for dropping columns with missing values
|
| 74 |
+
outlier_method: str = 'iqr' # Outlier detection method: iqr, zscore, lof
|
| 75 |
+
outlier_alpha: float = 1.5 # IQR multiplier
|
| 76 |
+
outlier_contamination: float = 0.1 # For methods like LOF
|
| 77 |
+
|
| 78 |
+
# Data splitting
|
| 79 |
+
test_size: float = 0.2
|
| 80 |
+
validation_size: float = 0.1
|
| 81 |
+
split_method: str = 'time' # time, random, expanding_window
|
| 82 |
+
|
| 83 |
+
# Scaling
|
| 84 |
+
scaling_method: str = 'robust' # standard, minmax, robust, none
|
| 85 |
+
|
| 86 |
+
# Feature selection
|
| 87 |
+
feature_selection_method: str = 'correlation' # correlation, mutual_info, rf, pca
|
| 88 |
+
max_features: int = 50
|
| 89 |
+
|
| 90 |
+
# Validation
|
| 91 |
+
enable_validation: bool = True
|
| 92 |
+
validation_rules: Dict = field(default_factory=dict)
|
| 93 |
+
|
| 94 |
+
# Visualisation
|
| 95 |
+
save_plots: bool = True
|
| 96 |
+
plot_style: str = 'seaborn'
|
| 97 |
+
|
| 98 |
+
# Performance
|
| 99 |
+
use_multiprocessing: bool = False
|
| 100 |
+
n_jobs: int = -1
|
| 101 |
+
chunk_size: int = 10000
|
| 102 |
+
|
| 103 |
+
# Logging
|
| 104 |
+
log_level: str = 'INFO'
|
| 105 |
+
save_reports: bool = True
|
| 106 |
+
|
| 107 |
+
def __post_init__(self):
|
| 108 |
+
"""Post-initialisation for creating directories and setting up logging"""
|
| 109 |
+
self.create_directories()
|
| 110 |
+
self.setup_logging()
|
| 111 |
+
|
| 112 |
+
# Setting default validation rules
|
| 113 |
+
if not self.validation_rules:
|
| 114 |
+
self.validation_rules = {
|
| 115 |
+
'min_rows': 100,
|
| 116 |
+
'max_missing_percentage': 30,
|
| 117 |
+
'min_unique_values': 2,
|
| 118 |
+
'max_skewness': 3,
|
| 119 |
+
'max_kurtosis': 10
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
def create_directories(self) -> None:
|
| 123 |
+
"""Create directories for preprocessing results"""
|
| 124 |
+
dirs = [
|
| 125 |
+
self.results_dir,
|
| 126 |
+
f'{self.results_dir}/plots',
|
| 127 |
+
f'{self.results_dir}/plots/time_series',
|
| 128 |
+
f'{self.results_dir}/plots/distributions',
|
| 129 |
+
f'{self.results_dir}/plots/correlations',
|
| 130 |
+
f'{self.results_dir}/plots/features',
|
| 131 |
+
f'{self.results_dir}/tables',
|
| 132 |
+
f'{self.results_dir}/processed_data',
|
| 133 |
+
f'{self.results_dir}/models',
|
| 134 |
+
f'{self.results_dir}/reports',
|
| 135 |
+
f'{self.results_dir}/logs',
|
| 136 |
+
f'{self.results_dir}/checkpoints'
|
| 137 |
+
]
|
| 138 |
+
|
| 139 |
+
for directory in dirs:
|
| 140 |
+
Path(directory).mkdir(parents=True, exist_ok=True)
|
| 141 |
+
|
| 142 |
+
logger.info(f"Directories created in {self.results_dir}")
|
| 143 |
+
|
| 144 |
+
def setup_logging(self) -> None:
|
| 145 |
+
"""Configure logging"""
|
| 146 |
+
log_level = getattr(logging, self.log_level.upper())
|
| 147 |
+
logger.setLevel(log_level)
|
| 148 |
+
|
| 149 |
+
def to_dict(self) -> Dict:
|
| 150 |
+
"""Convert configuration to dictionary"""
|
| 151 |
+
return asdict(self)
|
| 152 |
+
|
| 153 |
+
def save(self, path: Optional[str] = None) -> None:
|
| 154 |
+
"""Save configuration to file"""
|
| 155 |
+
if path is None:
|
| 156 |
+
path = f'{self.results_dir}/config.json'
|
| 157 |
+
|
| 158 |
+
with open(path, 'w', encoding='utf-8') as f:
|
| 159 |
+
json.dump(self.to_dict(), f, indent=4, ensure_ascii=False)
|
| 160 |
+
|
| 161 |
+
logger.info(f"Configuration saved to {path}")
|
| 162 |
+
|
| 163 |
+
@classmethod
|
| 164 |
+
def load(cls, path: str) -> 'Config':
|
| 165 |
+
"""Load configuration from file"""
|
| 166 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 167 |
+
config_dict = json.load(f)
|
| 168 |
+
|
| 169 |
+
return cls(**config_dict)
|
config/default_config.json
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"data_path": "temp_data.csv",
|
| 3 |
+
"results_dir": "results",
|
| 4 |
+
|
| 5 |
+
"start_year": 1970,
|
| 6 |
+
"end_year": 1990,
|
| 7 |
+
"freq": "D",
|
| 8 |
+
|
| 9 |
+
"target_column": "raskhodvoda",
|
| 10 |
+
|
| 11 |
+
"max_lags": 12,
|
| 12 |
+
"seasonal_period": 365,
|
| 13 |
+
"rolling_windows": [7, 30, 90, 365],
|
| 14 |
+
"expanding_windows": [30, 90, 365],
|
| 15 |
+
|
| 16 |
+
"missing_threshold": 0.3,
|
| 17 |
+
"outlier_method": "iqr",
|
| 18 |
+
"outlier_alpha": 1.5,
|
| 19 |
+
"outlier_contamination": 0.1,
|
| 20 |
+
|
| 21 |
+
"test_size": 0.2,
|
| 22 |
+
"validation_size": 0.1,
|
| 23 |
+
"split_method": "time",
|
| 24 |
+
|
| 25 |
+
"scaling_method": "robust",
|
| 26 |
+
|
| 27 |
+
"feature_selection_method": "correlation",
|
| 28 |
+
"max_features": 50,
|
| 29 |
+
|
| 30 |
+
"enable_validation": true,
|
| 31 |
+
"validation_rules": {
|
| 32 |
+
"min_rows": 100,
|
| 33 |
+
"max_missing_percentage": 30,
|
| 34 |
+
"min_unique_values": 2,
|
| 35 |
+
"max_skewness": 3,
|
| 36 |
+
"max_kurtosis": 10,
|
| 37 |
+
"min_variance": 0.001,
|
| 38 |
+
"max_constant_columns": 0
|
| 39 |
+
},
|
| 40 |
+
|
| 41 |
+
"save_plots": true,
|
| 42 |
+
"plot_style": "seaborn-whitegrid",
|
| 43 |
+
"plot_dpi": 300,
|
| 44 |
+
"plot_format": "png",
|
| 45 |
+
|
| 46 |
+
"use_multiprocessing": false,
|
| 47 |
+
"n_jobs": -1,
|
| 48 |
+
"chunk_size": 10000,
|
| 49 |
+
"memory_limit_gb": 4,
|
| 50 |
+
|
| 51 |
+
"log_level": "INFO",
|
| 52 |
+
"save_reports": true,
|
| 53 |
+
"report_format": "json",
|
| 54 |
+
|
| 55 |
+
"decomposition_method": "stl",
|
| 56 |
+
"stationarity_tests": ["adf", "kpss"],
|
| 57 |
+
"correlation_threshold": 0.85,
|
| 58 |
+
"vif_threshold": 10,
|
| 59 |
+
|
| 60 |
+
"random_seed": 42,
|
| 61 |
+
"enable_profiling": false,
|
| 62 |
+
"save_intermediate": true,
|
| 63 |
+
|
| 64 |
+
"streamlit_settings": {
|
| 65 |
+
"theme": "light",
|
| 66 |
+
"sidebar_state": "expanded",
|
| 67 |
+
"page_title": "Time Series Preprocessing",
|
| 68 |
+
"page_icon": "📊",
|
| 69 |
+
"layout": "wide"
|
| 70 |
+
},
|
| 71 |
+
|
| 72 |
+
"export_options": {
|
| 73 |
+
"csv": true,
|
| 74 |
+
"parquet": false,
|
| 75 |
+
"excel": false,
|
| 76 |
+
"pickle": true
|
| 77 |
+
}
|
| 78 |
+
}
|
config/settings.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
General project settings: visualisation, paths, constants
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import warnings
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
import yaml
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
# ============================================================================
|
| 15 |
+
# PATHS AND DIRECTORIES
|
| 16 |
+
# ============================================================================
|
| 17 |
+
|
| 18 |
+
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
| 19 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 20 |
+
RAW_DATA_DIR = DATA_DIR / "raw"
|
| 21 |
+
PROCESSED_DATA_DIR = DATA_DIR / "processed"
|
| 22 |
+
EXTERNAL_DATA_DIR = DATA_DIR / "external"
|
| 23 |
+
|
| 24 |
+
RESULTS_DIR = PROJECT_ROOT / "results"
|
| 25 |
+
PLOTS_DIR = RESULTS_DIR / "plots"
|
| 26 |
+
MODELS_DIR = RESULTS_DIR / "models"
|
| 27 |
+
REPORTS_DIR = RESULTS_DIR / "reports"
|
| 28 |
+
LOGS_DIR = RESULTS_DIR / "logs"
|
| 29 |
+
|
| 30 |
+
CONFIGS_DIR = PROJECT_ROOT / "configs"
|
| 31 |
+
NOTEBOOKS_DIR = PROJECT_ROOT / "notebooks"
|
| 32 |
+
TESTS_DIR = PROJECT_ROOT / "tests"
|
| 33 |
+
|
| 34 |
+
# Create directories on import
|
| 35 |
+
for directory in [RAW_DATA_DIR, PROCESSED_DATA_DIR, EXTERNAL_DATA_DIR,
|
| 36 |
+
PLOTS_DIR, MODELS_DIR, REPORTS_DIR, LOGS_DIR]:
|
| 37 |
+
directory.mkdir(parents=True, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
# ============================================================================
|
| 40 |
+
# VISUALISATION SETTINGS
|
| 41 |
+
# ============================================================================
|
| 42 |
+
|
| 43 |
+
def setup_visualization(
|
| 44 |
+
style: str = "seaborn-whitegrid",
|
| 45 |
+
palette: str = "husl",
|
| 46 |
+
context: str = "notebook",
|
| 47 |
+
font_scale: float = 1.0,
|
| 48 |
+
dpi: int = 150,
|
| 49 |
+
figsize: tuple = (12, 6),
|
| 50 |
+
**kwargs
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Configure visualisation parameters for matplotlib and seaborn
|
| 54 |
+
|
| 55 |
+
Parameters:
|
| 56 |
+
-----------
|
| 57 |
+
style : str
|
| 58 |
+
Matplotlib style: 'seaborn-whitegrid', 'ggplot', 'bmh', 'dark_background'
|
| 59 |
+
palette : str
|
| 60 |
+
Seaborn palette: 'husl', 'Set2', 'viridis', 'mako'
|
| 61 |
+
context : str
|
| 62 |
+
Seaborn context: 'paper', 'notebook', 'talk', 'poster'
|
| 63 |
+
font_scale : float
|
| 64 |
+
Font scale
|
| 65 |
+
dpi : int
|
| 66 |
+
Plot resolution
|
| 67 |
+
figsize : tuple
|
| 68 |
+
Default figure size
|
| 69 |
+
"""
|
| 70 |
+
# Ignore warnings
|
| 71 |
+
warnings.filterwarnings('ignore')
|
| 72 |
+
|
| 73 |
+
# Matplotlib settings
|
| 74 |
+
plt.style.use(style)
|
| 75 |
+
|
| 76 |
+
# RC parameters
|
| 77 |
+
rc_params = {
|
| 78 |
+
'font.size': 10,
|
| 79 |
+
'figure.figsize': figsize,
|
| 80 |
+
'figure.dpi': dpi,
|
| 81 |
+
'savefig.dpi': 300,
|
| 82 |
+
'savefig.bbox': 'tight',
|
| 83 |
+
'savefig.format': 'png',
|
| 84 |
+
'axes.titlesize': 12,
|
| 85 |
+
'axes.labelsize': 10,
|
| 86 |
+
'xtick.labelsize': 9,
|
| 87 |
+
'ytick.labelsize': 9,
|
| 88 |
+
'legend.fontsize': 9,
|
| 89 |
+
'font.family': ['DejaVu Sans', 'Arial', 'sans-serif'],
|
| 90 |
+
'figure.titlesize': 14,
|
| 91 |
+
'axes.grid': True,
|
| 92 |
+
'grid.alpha': 0.3,
|
| 93 |
+
'lines.linewidth': 1.5,
|
| 94 |
+
'lines.markersize': 6,
|
| 95 |
+
'patch.edgecolor': 'black',
|
| 96 |
+
'patch.force_edgecolor': True,
|
| 97 |
+
'xtick.top': False,
|
| 98 |
+
'ytick.right': False,
|
| 99 |
+
'axes.spines.top': False,
|
| 100 |
+
'axes.spines.right': False
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Update additional parameters
|
| 104 |
+
rc_params.update(kwargs)
|
| 105 |
+
plt.rcParams.update(rc_params)
|
| 106 |
+
|
| 107 |
+
# Seaborn settings
|
| 108 |
+
sns.set_style(style.replace('seaborn-', ''))
|
| 109 |
+
sns.set_palette(palette)
|
| 110 |
+
sns.set_context(context, font_scale=font_scale)
|
| 111 |
+
|
| 112 |
+
print(f"✓ Visualisation settings applied: style={style}, palette={palette}")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_color_palette(name: str = "husl", n_colors: int = 8) -> list:
|
| 116 |
+
"""
|
| 117 |
+
Get colour palette
|
| 118 |
+
|
| 119 |
+
Parameters:
|
| 120 |
+
-----------
|
| 121 |
+
name : str
|
| 122 |
+
Palette name
|
| 123 |
+
n_colors : int
|
| 124 |
+
Number of colours
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
--------
|
| 128 |
+
list
|
| 129 |
+
List of colours in HEX format
|
| 130 |
+
"""
|
| 131 |
+
palette_map = {
|
| 132 |
+
"husl": sns.color_palette("husl", n_colors),
|
| 133 |
+
"Set2": sns.color_palette("Set2", n_colors),
|
| 134 |
+
"Set3": sns.color_palette("Set3", n_colors),
|
| 135 |
+
"viridis": sns.color_palette("viridis", n_colors),
|
| 136 |
+
"plasma": sns.color_palette("plasma", n_colors),
|
| 137 |
+
"coolwarm": sns.color_palette("coolwarm", n_colors),
|
| 138 |
+
"RdYlBu": sns.color_palette("RdYlBu", n_colors),
|
| 139 |
+
"Spectral": sns.color_palette("Spectral", n_colors),
|
| 140 |
+
"tab10": sns.color_palette("tab10", n_colors),
|
| 141 |
+
"tab20": sns.color_palette("tab20", n_colors),
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
palette = palette_map.get(name, sns.color_palette("husl", n_colors))
|
| 145 |
+
return [f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
|
| 146 |
+
for r, g, b in palette]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ============================================================================
|
| 150 |
+
# CONSTANTS
|
| 151 |
+
# ============================================================================
|
| 152 |
+
|
| 153 |
+
# Data types
|
| 154 |
+
DATETIME_FORMATS = [
|
| 155 |
+
"%Y-%m-%d", "%Y/%m/%d", "%d.%m.%Y", "%d/%m/%Y",
|
| 156 |
+
"%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S",
|
| 157 |
+
"%d.%m.%Y %H:%M:%S", "%d/%m/%Y %H:%M:%S"
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
# Metrics
|
| 161 |
+
METRICS = {
|
| 162 |
+
"regression": ["mse", "rmse", "mae", "mape", "r2", "explained_variance"],
|
| 163 |
+
"classification": ["accuracy", "precision", "recall", "f1", "roc_auc"]
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
# Statistical constants
|
| 167 |
+
STATS_CONSTANTS = {
|
| 168 |
+
"confidence_levels": [0.9, 0.95, 0.99],
|
| 169 |
+
"z_scores": {0.9: 1.645, 0.95: 1.96, 0.99: 2.576},
|
| 170 |
+
"outlier_multipliers": {"mild": 1.5, "extreme": 3.0}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# Time series parameters
|
| 174 |
+
TIME_SERIES_CONSTANTS = {
|
| 175 |
+
"frequencies": {
|
| 176 |
+
"H": "hourly",
|
| 177 |
+
"D": "daily",
|
| 178 |
+
"W": "weekly",
|
| 179 |
+
"M": "monthly",
|
| 180 |
+
"Q": "quarterly",
|
| 181 |
+
"Y": "yearly"
|
| 182 |
+
},
|
| 183 |
+
"seasonal_periods": {
|
| 184 |
+
"hourly": 24,
|
| 185 |
+
"daily": 7,
|
| 186 |
+
"weekly": 52,
|
| 187 |
+
"monthly": 12,
|
| 188 |
+
"quarterly": 4,
|
| 189 |
+
"yearly": 1
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
# ============================================================================
|
| 194 |
+
# CONFIGURATION UTILITIES
|
| 195 |
+
# ============================================================================
|
| 196 |
+
|
| 197 |
+
def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
|
| 198 |
+
"""
|
| 199 |
+
Load configuration from file
|
| 200 |
+
|
| 201 |
+
Parameters:
|
| 202 |
+
-----------
|
| 203 |
+
config_path : str, optional
|
| 204 |
+
Path to configuration file
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
--------
|
| 208 |
+
Dict[str, Any]
|
| 209 |
+
Configuration dictionary
|
| 210 |
+
"""
|
| 211 |
+
if config_path is None:
|
| 212 |
+
config_path = CONFIGS_DIR / "default_config.json"
|
| 213 |
+
|
| 214 |
+
config_path = Path(config_path)
|
| 215 |
+
|
| 216 |
+
if not config_path.exists():
|
| 217 |
+
print(f"⚠ Configuration file not found: {config_path}")
|
| 218 |
+
return {}
|
| 219 |
+
|
| 220 |
+
# Determine file format
|
| 221 |
+
if config_path.suffix.lower() in ['.json']:
|
| 222 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 223 |
+
config = json.load(f)
|
| 224 |
+
elif config_path.suffix.lower() in ['.yaml', '.yml']:
|
| 225 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 226 |
+
config = yaml.safe_load(f)
|
| 227 |
+
else:
|
| 228 |
+
raise ValueError(f"Unsupported file format: {config_path.suffix}")
|
| 229 |
+
|
| 230 |
+
print(f"✓ Configuration loaded from: {config_path}")
|
| 231 |
+
return config
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def save_config(config: Dict[str, Any], config_path: str) -> None:
|
| 235 |
+
"""
|
| 236 |
+
Save configuration to file
|
| 237 |
+
|
| 238 |
+
Parameters:
|
| 239 |
+
-----------
|
| 240 |
+
config : Dict[str, Any]
|
| 241 |
+
Configuration to save
|
| 242 |
+
config_path : str
|
| 243 |
+
Save path
|
| 244 |
+
"""
|
| 245 |
+
config_path = Path(config_path)
|
| 246 |
+
config_path.parent.mkdir(parents=True, exist_ok=True)
|
| 247 |
+
|
| 248 |
+
# Determine format
|
| 249 |
+
if config_path.suffix.lower() in ['.json']:
|
| 250 |
+
with open(config_path, 'w', encoding='utf-8') as f:
|
| 251 |
+
json.dump(config, f, indent=2, ensure_ascii=False)
|
| 252 |
+
elif config_path.suffix.lower() in ['.yaml', '.yml']:
|
| 253 |
+
with open(config_path, 'w', encoding='utf-8') as f:
|
| 254 |
+
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
| 255 |
+
else:
|
| 256 |
+
raise ValueError(f"Unsupported file format: {config_path.suffix}")
|
| 257 |
+
|
| 258 |
+
print(f"✓ Configuration saved to: {config_path}")
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def merge_configs(base_config: Dict[str, Any],
|
| 262 |
+
override_config: Dict[str, Any]) -> Dict[str, Any]:
|
| 263 |
+
"""
|
| 264 |
+
Recursive configuration merging
|
| 265 |
+
|
| 266 |
+
Parameters:
|
| 267 |
+
-----------
|
| 268 |
+
base_config : Dict[str, Any]
|
| 269 |
+
Base configuration
|
| 270 |
+
override_config : Dict[str, Any]
|
| 271 |
+
Override configuration
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
--------
|
| 275 |
+
Dict[str, Any]
|
| 276 |
+
Merged configuration
|
| 277 |
+
"""
|
| 278 |
+
result = base_config.copy()
|
| 279 |
+
|
| 280 |
+
for key, value in override_config.items():
|
| 281 |
+
if (key in result and isinstance(result[key], dict)
|
| 282 |
+
and isinstance(value, dict)):
|
| 283 |
+
result[key] = merge_configs(result[key], value)
|
| 284 |
+
else:
|
| 285 |
+
result[key] = value
|
| 286 |
+
|
| 287 |
+
return result
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ============================================================================
|
| 291 |
+
# ENVIRONMENT SETUP
|
| 292 |
+
# ============================================================================
|
| 293 |
+
|
| 294 |
+
def setup_environment(
|
| 295 |
+
log_level: str = "INFO",
|
| 296 |
+
random_seed: int = 42,
|
| 297 |
+
enable_warnings: bool = False,
|
| 298 |
+
memory_limit_gb: Optional[int] = None
|
| 299 |
+
) -> None:
|
| 300 |
+
"""
|
| 301 |
+
Set up environment for reproducibility
|
| 302 |
+
|
| 303 |
+
Parameters:
|
| 304 |
+
-----------
|
| 305 |
+
log_level : str
|
| 306 |
+
Logging level
|
| 307 |
+
random_seed : int
|
| 308 |
+
Seed for random generators
|
| 309 |
+
enable_warnings : bool
|
| 310 |
+
Enable warnings
|
| 311 |
+
memory_limit_gb : int, optional
|
| 312 |
+
Memory limit in GB
|
| 313 |
+
"""
|
| 314 |
+
import numpy as np
|
| 315 |
+
import random
|
| 316 |
+
import torch
|
| 317 |
+
import tensorflow as tf
|
| 318 |
+
|
| 319 |
+
# Set seeds
|
| 320 |
+
np.random.seed(random_seed)
|
| 321 |
+
random.seed(random_seed)
|
| 322 |
+
|
| 323 |
+
try:
|
| 324 |
+
torch.manual_seed(random_seed)
|
| 325 |
+
except:
|
| 326 |
+
pass
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
tf.random.set_seed(random_seed)
|
| 330 |
+
except:
|
| 331 |
+
pass
|
| 332 |
+
|
| 333 |
+
# Configure warnings
|
| 334 |
+
if enable_warnings:
|
| 335 |
+
warnings.filterwarnings('default')
|
| 336 |
+
else:
|
| 337 |
+
warnings.filterwarnings('ignore')
|
| 338 |
+
|
| 339 |
+
# Memory limit (if specified)
|
| 340 |
+
if memory_limit_gb:
|
| 341 |
+
import resource
|
| 342 |
+
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
|
| 343 |
+
memory_limit = memory_limit_gb * 1024**3 # GB to bytes
|
| 344 |
+
resource.setrlimit(resource.RLIMIT_AS, (memory_limit, hard))
|
| 345 |
+
print(f"✓ Memory limit set: {memory_limit_gb} GB")
|
| 346 |
+
|
| 347 |
+
print(f"✓ Environment configured. Random seed: {random_seed}")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# ============================================================================
|
| 351 |
+
# AUTOMATIC SETUP ON IMPORT
|
| 352 |
+
# ============================================================================
|
| 353 |
+
|
| 354 |
+
# Automatically apply visualisation settings
|
| 355 |
+
setup_visualization()
|
| 356 |
+
|
| 357 |
+
# Export useful variables
|
| 358 |
+
__all__ = [
|
| 359 |
+
'setup_visualization',
|
| 360 |
+
'get_color_palette',
|
| 361 |
+
'load_config',
|
| 362 |
+
'save_config',
|
| 363 |
+
'merge_configs',
|
| 364 |
+
'setup_environment',
|
| 365 |
+
'PROJECT_ROOT',
|
| 366 |
+
'DATA_DIR',
|
| 367 |
+
'RAW_DATA_DIR',
|
| 368 |
+
'PROCESSED_DATA_DIR',
|
| 369 |
+
'RESULTS_DIR',
|
| 370 |
+
'PLOTS_DIR',
|
| 371 |
+
'DATETIME_FORMATS',
|
| 372 |
+
'METRICS',
|
| 373 |
+
'STATS_CONSTANTS',
|
| 374 |
+
'TIME_SERIES_CONSTANTS'
|
| 375 |
+
]
|
correlations/__init__.py
ADDED
|
File without changes
|
correlations/correlation_analyzer.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 8: CORRELATION AND MULTICOLLINEARITY ANALYSIS
|
| 3 |
+
# ============================================
|
| 4 |
+
import os
|
| 5 |
+
import traceback
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
from venv import logger
|
| 8 |
+
|
| 9 |
+
from config.config import Config
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
class CorrelationAnalyzer:
|
| 14 |
+
"""Class for comprehensive correlation and multicollinearity analysis"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, config: Config):
|
| 17 |
+
"""
|
| 18 |
+
Initialise the analyser
|
| 19 |
+
|
| 20 |
+
Parameters:
|
| 21 |
+
-----------
|
| 22 |
+
config : Config
|
| 23 |
+
Experiment configuration
|
| 24 |
+
"""
|
| 25 |
+
self.config = config
|
| 26 |
+
self.correlation_matrices = {}
|
| 27 |
+
self.high_correlation_pairs = {}
|
| 28 |
+
self.multicollinearity_info = {}
|
| 29 |
+
self.vif_scores = {}
|
| 30 |
+
|
| 31 |
+
def analyze(
|
| 32 |
+
self,
|
| 33 |
+
data: pd.DataFrame,
|
| 34 |
+
target_col: Optional[str] = None,
|
| 35 |
+
threshold: float = 0.8,
|
| 36 |
+
detailed: bool = True,
|
| 37 |
+
**kwargs
|
| 38 |
+
) -> pd.DataFrame:
|
| 39 |
+
"""
|
| 40 |
+
Analyse correlations in the data
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
-----------
|
| 44 |
+
data : pd.DataFrame
|
| 45 |
+
Input data
|
| 46 |
+
target_col : str, optional
|
| 47 |
+
Target variable
|
| 48 |
+
threshold : float
|
| 49 |
+
Threshold for identifying high correlations
|
| 50 |
+
detailed : bool
|
| 51 |
+
Whether to perform detailed analysis
|
| 52 |
+
**kwargs : dict
|
| 53 |
+
Additional parameters
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
--------
|
| 57 |
+
pd.DataFrame
|
| 58 |
+
Correlation matrix
|
| 59 |
+
"""
|
| 60 |
+
logger.info("\n" + "="*80)
|
| 61 |
+
logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS")
|
| 62 |
+
logger.info("="*80)
|
| 63 |
+
|
| 64 |
+
target_col = target_col or self.config.target_column
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
# 1. Calculate correlation matrix
|
| 68 |
+
corr_matrix = self._compute_correlations(data, target_col)
|
| 69 |
+
|
| 70 |
+
if corr_matrix.empty:
|
| 71 |
+
logger.warning("Correlation matrix is empty")
|
| 72 |
+
return pd.DataFrame()
|
| 73 |
+
|
| 74 |
+
# 2. Identify high correlations
|
| 75 |
+
high_correlations = self._detect_high_correlations(corr_matrix, threshold)
|
| 76 |
+
self.high_correlation_pairs['pearson'] = high_correlations
|
| 77 |
+
|
| 78 |
+
# 3. Analyse correlations with target variable
|
| 79 |
+
target_correlations = []
|
| 80 |
+
if target_col in corr_matrix.columns:
|
| 81 |
+
target_correlations = self._get_target_correlations(corr_matrix, target_col)
|
| 82 |
+
|
| 83 |
+
# 4. Analyse multicollinearity (VIF)
|
| 84 |
+
vif_results = self._compute_vif_scores(data)
|
| 85 |
+
|
| 86 |
+
# 5. Detailed analysis if required
|
| 87 |
+
if detailed:
|
| 88 |
+
self._detailed_correlation_analysis(data, corr_matrix, target_col)
|
| 89 |
+
|
| 90 |
+
# 6. Visualisation
|
| 91 |
+
if self.config.save_plots:
|
| 92 |
+
self._plot_correlation_analysis(data, corr_matrix, target_col, high_correlations, vif_results)
|
| 93 |
+
|
| 94 |
+
# 7. Output results
|
| 95 |
+
self._log_analysis_results(corr_matrix, high_correlations, target_correlations, vif_results)
|
| 96 |
+
|
| 97 |
+
return corr_matrix
|
| 98 |
+
|
| 99 |
+
except Exception as e:
|
| 100 |
+
logger.error(f"Error in correlation analysis: {e}")
|
| 101 |
+
logger.error(traceback.format_exc())
|
| 102 |
+
return pd.DataFrame()
|
| 103 |
+
|
| 104 |
+
def _compute_correlations(
|
| 105 |
+
self,
|
| 106 |
+
data: pd.DataFrame,
|
| 107 |
+
target_col: str
|
| 108 |
+
) -> pd.DataFrame:
|
| 109 |
+
"""Calculate correlation matrix"""
|
| 110 |
+
logger.info("Calculating correlation matrix...")
|
| 111 |
+
|
| 112 |
+
# Select only numeric columns
|
| 113 |
+
numeric_data = data.select_dtypes(include=[np.number])
|
| 114 |
+
|
| 115 |
+
# Remove constant columns
|
| 116 |
+
numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
|
| 117 |
+
|
| 118 |
+
if numeric_data.shape[1] < 2:
|
| 119 |
+
logger.warning("Insufficient numeric features for analysis")
|
| 120 |
+
return pd.DataFrame()
|
| 121 |
+
|
| 122 |
+
# Remove missing values
|
| 123 |
+
numeric_data_clean = numeric_data.dropna()
|
| 124 |
+
|
| 125 |
+
if len(numeric_data_clean) < 10:
|
| 126 |
+
logger.warning("Insufficient data after cleaning")
|
| 127 |
+
return pd.DataFrame()
|
| 128 |
+
|
| 129 |
+
# Calculate Pearson correlation
|
| 130 |
+
try:
|
| 131 |
+
corr_matrix = numeric_data_clean.corr(method='pearson')
|
| 132 |
+
self.correlation_matrices['pearson'] = corr_matrix
|
| 133 |
+
logger.info(f"✓ Correlation matrix calculated: {corr_matrix.shape}")
|
| 134 |
+
return corr_matrix
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"Error calculating correlation: {e}")
|
| 137 |
+
return pd.DataFrame()
|
| 138 |
+
|
| 139 |
+
def _detect_high_correlations(
|
| 140 |
+
self,
|
| 141 |
+
corr_matrix: pd.DataFrame,
|
| 142 |
+
threshold: float = 0.8
|
| 143 |
+
) -> List[Dict[str, Any]]:
|
| 144 |
+
"""Detect high correlations"""
|
| 145 |
+
high_correlations = []
|
| 146 |
+
|
| 147 |
+
if corr_matrix.empty:
|
| 148 |
+
return high_correlations
|
| 149 |
+
|
| 150 |
+
# Use upper triangle of matrix
|
| 151 |
+
upper_triangle = corr_matrix.where(
|
| 152 |
+
np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Find pairs with correlation above threshold
|
| 156 |
+
for col in upper_triangle.columns:
|
| 157 |
+
if col in upper_triangle:
|
| 158 |
+
high_corr_series = upper_triangle[col][abs(upper_triangle[col]) > threshold]
|
| 159 |
+
|
| 160 |
+
for row_idx, correlation in high_corr_series.items():
|
| 161 |
+
if not pd.isna(correlation):
|
| 162 |
+
high_correlations.append({
|
| 163 |
+
'feature1': row_idx,
|
| 164 |
+
'feature2': col,
|
| 165 |
+
'correlation': float(correlation),
|
| 166 |
+
'abs_correlation': abs(float(correlation))
|
| 167 |
+
})
|
| 168 |
+
|
| 169 |
+
# Sort by absolute correlation value
|
| 170 |
+
high_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
|
| 171 |
+
|
| 172 |
+
logger.info(f"High correlations detected (> {threshold}): {len(high_correlations)}")
|
| 173 |
+
return high_correlations
|
| 174 |
+
|
| 175 |
+
def _get_target_correlations(
|
| 176 |
+
self,
|
| 177 |
+
corr_matrix: pd.DataFrame,
|
| 178 |
+
target_col: str
|
| 179 |
+
) -> List[Dict[str, Any]]:
|
| 180 |
+
"""Get correlations with target variable"""
|
| 181 |
+
target_correlations = []
|
| 182 |
+
|
| 183 |
+
if target_col not in corr_matrix.columns:
|
| 184 |
+
return target_correlations
|
| 185 |
+
|
| 186 |
+
# Extract correlations with target variable
|
| 187 |
+
target_corr_series = corr_matrix[target_col]
|
| 188 |
+
|
| 189 |
+
for feature, correlation in target_corr_series.items():
|
| 190 |
+
if feature != target_col and not pd.isna(correlation):
|
| 191 |
+
target_correlations.append({
|
| 192 |
+
'feature': feature,
|
| 193 |
+
'correlation': float(correlation),
|
| 194 |
+
'abs_correlation': abs(float(correlation)),
|
| 195 |
+
'direction': 'positive' if correlation > 0 else 'negative'
|
| 196 |
+
})
|
| 197 |
+
|
| 198 |
+
# Sort by absolute value
|
| 199 |
+
target_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)
|
| 200 |
+
|
| 201 |
+
logger.info(f"Correlations with target variable calculated: {len(target_correlations)}")
|
| 202 |
+
return target_correlations
|
| 203 |
+
|
| 204 |
+
def _compute_vif_scores(self, data: pd.DataFrame) -> Dict[str, Any]:
|
| 205 |
+
"""Calculate VIF (Variance Inflation Factor)"""
|
| 206 |
+
logger.info("Analysing multicollinearity (VIF)...")
|
| 207 |
+
|
| 208 |
+
vif_results = {
|
| 209 |
+
'scores': {},
|
| 210 |
+
'issues': [],
|
| 211 |
+
'summary': {
|
| 212 |
+
'critical': 0,
|
| 213 |
+
'high': 0,
|
| 214 |
+
'medium': 0,
|
| 215 |
+
'low': 0
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
| 221 |
+
import statsmodels.api as sm
|
| 222 |
+
|
| 223 |
+
# Prepare data
|
| 224 |
+
numeric_data = data.select_dtypes(include=[np.number])
|
| 225 |
+
numeric_data = numeric_data.loc[:, numeric_data.nunique() > 1]
|
| 226 |
+
|
| 227 |
+
# Remove missing and infinite values
|
| 228 |
+
clean_data = numeric_data.replace([np.inf, -np.inf], np.nan).dropna()
|
| 229 |
+
|
| 230 |
+
if clean_data.shape[0] < 10 or clean_data.shape[1] < 2:
|
| 231 |
+
logger.warning("Insufficient data for VIF analysis")
|
| 232 |
+
return vif_results
|
| 233 |
+
|
| 234 |
+
# Add constant
|
| 235 |
+
X = sm.add_constant(clean_data, has_constant='add')
|
| 236 |
+
|
| 237 |
+
# Calculate VIF for each feature
|
| 238 |
+
vif_scores = {}
|
| 239 |
+
for i, column in enumerate(X.columns):
|
| 240 |
+
if column == 'const':
|
| 241 |
+
continue
|
| 242 |
+
|
| 243 |
+
try:
|
| 244 |
+
vif = variance_inflation_factor(X.values, i)
|
| 245 |
+
|
| 246 |
+
# Handle extreme values
|
| 247 |
+
if np.isinf(vif) or vif > 1e6:
|
| 248 |
+
vif = 1e6
|
| 249 |
+
|
| 250 |
+
vif_scores[column] = float(vif)
|
| 251 |
+
|
| 252 |
+
# Classify by severity
|
| 253 |
+
if vif > 100:
|
| 254 |
+
vif_results['summary']['critical'] += 1
|
| 255 |
+
vif_results['issues'].append({
|
| 256 |
+
'feature': column,
|
| 257 |
+
'vif': float(vif),
|
| 258 |
+
'severity': 'critical',
|
| 259 |
+
'recommendation': 'Remove feature'
|
| 260 |
+
})
|
| 261 |
+
elif vif > 10:
|
| 262 |
+
vif_results['summary']['high'] += 1
|
| 263 |
+
vif_results['issues'].append({
|
| 264 |
+
'feature': column,
|
| 265 |
+
'vif': float(vif),
|
| 266 |
+
'severity': 'high',
|
| 267 |
+
'recommendation': 'Consider removal'
|
| 268 |
+
})
|
| 269 |
+
elif vif > 5:
|
| 270 |
+
vif_results['summary']['medium'] += 1
|
| 271 |
+
else:
|
| 272 |
+
vif_results['summary']['low'] += 1
|
| 273 |
+
|
| 274 |
+
except Exception as e:
|
| 275 |
+
logger.warning(f"VIF error for {column}: {e}")
|
| 276 |
+
vif_scores[column] = np.nan
|
| 277 |
+
|
| 278 |
+
vif_results['scores'] = vif_scores
|
| 279 |
+
self.vif_scores = vif_scores
|
| 280 |
+
|
| 281 |
+
logger.info(f"✓ VIF analysis completed. Critical features: {vif_results['summary']['critical']}")
|
| 282 |
+
|
| 283 |
+
except ImportError:
|
| 284 |
+
logger.warning("statsmodels not installed, skipping VIF analysis")
|
| 285 |
+
except Exception as e:
|
| 286 |
+
logger.error(f"VIF analysis error: {e}")
|
| 287 |
+
|
| 288 |
+
return vif_results
|
| 289 |
+
|
| 290 |
+
def _detailed_correlation_analysis(
|
| 291 |
+
self,
|
| 292 |
+
data: pd.DataFrame,
|
| 293 |
+
corr_matrix: pd.DataFrame,
|
| 294 |
+
target_col: str
|
| 295 |
+
) -> None:
|
| 296 |
+
"""Detailed correlation analysis"""
|
| 297 |
+
# Analyse correlation clusters
|
| 298 |
+
if not corr_matrix.empty and corr_matrix.shape[0] > 3:
|
| 299 |
+
try:
|
| 300 |
+
# Use clustering to group correlated features
|
| 301 |
+
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
|
| 302 |
+
from scipy.spatial.distance import squareform
|
| 303 |
+
|
| 304 |
+
# Convert correlations to distances
|
| 305 |
+
distance_matrix = 1 - abs(corr_matrix)
|
| 306 |
+
np.fill_diagonal(distance_matrix.values, 0)
|
| 307 |
+
|
| 308 |
+
# Clustering
|
| 309 |
+
condensed_dist = squareform(distance_matrix)
|
| 310 |
+
Z = linkage(condensed_dist, method='average')
|
| 311 |
+
|
| 312 |
+
# Determine clusters
|
| 313 |
+
clusters = fcluster(Z, t=0.5, criterion='distance')
|
| 314 |
+
|
| 315 |
+
# Group features by cluster
|
| 316 |
+
feature_clusters = {}
|
| 317 |
+
for idx, cluster_id in enumerate(clusters):
|
| 318 |
+
feature = corr_matrix.columns[idx]
|
| 319 |
+
if cluster_id not in feature_clusters:
|
| 320 |
+
feature_clusters[cluster_id] = []
|
| 321 |
+
feature_clusters[cluster_id].append(feature)
|
| 322 |
+
|
| 323 |
+
# Save cluster information
|
| 324 |
+
self.multicollinearity_info['correlation_clusters'] = feature_clusters
|
| 325 |
+
logger.info(f"Correlated feature clusters detected: {len(feature_clusters)}")
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
logger.debug(f"Cluster analysis failed: {e}")
|
| 329 |
+
|
| 330 |
+
    def _plot_correlation_analysis(
        self,
        data: pd.DataFrame,
        corr_matrix: pd.DataFrame,
        target_col: str,
        high_correlations: List[Dict[str, Any]],
        vif_results: Dict[str, Any]
    ) -> None:
        """Visualise correlation analysis"""
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            from matplotlib import rcParams

            # Style settings
            plt.style.use('seaborn-v0_8-darkgrid')
            rcParams.update({
                'figure.figsize': (12, 8),
                'font.size': 10,
                'axes.titlesize': 14,
                'axes.labelsize': 12
            })

            # Create directory
            plots_dir = os.path.join(self.config.results_dir, 'plots', 'correlations')
            os.makedirs(plots_dir, exist_ok=True)

            # 1. Correlation matrix heatmap
            if not corr_matrix.empty and corr_matrix.shape[0] > 1:
                fig, ax = plt.subplots(figsize=(14, 12))

                mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
                sns.heatmap(
                    corr_matrix,
                    mask=mask,
                    annot=True,
                    fmt='.2f',
                    cmap='coolwarm',
                    center=0,
                    square=True,
                    linewidths=0.5,
                    cbar_kws={"shrink": 0.8},
                    ax=ax
                )
                ax.set_title('Correlation Matrix (Pearson)', fontweight='bold')
                plt.tight_layout()
                plt.savefig(os.path.join(plots_dir, 'correlation_matrix.png'),
                            dpi=150, bbox_inches='tight')
                plt.close()

            # 2. Target variable correlations
            if target_col in corr_matrix.columns:
                target_corrs = corr_matrix[target_col].drop(target_col, errors='ignore')
                if not target_corrs.empty:
                    fig, ax = plt.subplots(figsize=(10, 8))

                    top_corrs = target_corrs.abs().sort_values(ascending=True).tail(20)
                    colors = ['red' if target_corrs[feat] < 0 else 'blue'
                              for feat in top_corrs.index]

                    ax.barh(range(len(top_corrs)), top_corrs.values, color=colors)
                    ax.set_yticks(range(len(top_corrs)))
                    ax.set_yticklabels(top_corrs.index)
                    ax.set_xlabel('Absolute correlation')
                    ax.set_title(f'Top-20 correlations with {target_col}', fontweight='bold')
                    ax.grid(True, alpha=0.3, axis='x')

                    plt.tight_layout()
                    plt.savefig(os.path.join(plots_dir, 'target_correlations.png'),
                                dpi=150, bbox_inches='tight')
                    plt.close()

            # 3. VIF scores plot
            if vif_results['scores']:
                valid_scores = {k: v for k, v in vif_results['scores'].items()
                                if not pd.isna(v)}
                if valid_scores:
                    fig, ax = plt.subplots(figsize=(12, 8))

                    sorted_scores = dict(sorted(valid_scores.items(),
                                                key=lambda x: x[1],
                                                reverse=True)[:25])

                    colors = []
                    for vif in sorted_scores.values():
                        if vif > 100:
                            colors.append('red')
                        elif vif > 10:
                            colors.append('orange')
                        elif vif > 5:
                            colors.append('yellow')
                        else:
                            colors.append('green')

                    bars = ax.barh(list(sorted_scores.keys()),
                                   list(sorted_scores.values()),
                                   color=colors, edgecolor='black')

                    ax.set_xlabel('VIF Score')
                    ax.set_title('VIF Scores (multicollinearity)', fontweight='bold')
                    ax.axvline(x=5, color='yellow', linestyle='--', alpha=0.7)
                    ax.axvline(x=10, color='orange', linestyle='--', alpha=0.7)
                    ax.axvline(x=100, color='red', linestyle='--', alpha=0.7)
                    ax.grid(True, alpha=0.3, axis='x')

                    plt.tight_layout()
                    plt.savefig(os.path.join(plots_dir, 'vif_scores.png'),
                                dpi=150, bbox_inches='tight')
                    plt.close()

            # 4. High correlations plot
            if high_correlations:
                fig, ax = plt.subplots(figsize=(12, 8))

                # Limit number for display
                display_corrs = high_correlations[:15]

                # Create labels for feature pairs
                labels = [f"{corr['feature1']} ↔ {corr['feature2']}"
                          for corr in display_corrs]
                values = [corr['correlation'] for corr in display_corrs]
                colors = ['red' if v < 0 else 'blue' for v in values]

                y_pos = np.arange(len(display_corrs))
                ax.barh(y_pos, values, color=colors)
                ax.set_yticks(y_pos)
                ax.set_yticklabels(labels, fontsize=9)
                ax.invert_yaxis()
                ax.set_xlabel('Correlation')
                ax.set_title('High correlations (> 0.8)', fontweight='bold')
                ax.grid(True, alpha=0.3, axis='x')

                plt.tight_layout()
                plt.savefig(os.path.join(plots_dir, 'high_correlations.png'),
                            dpi=150, bbox_inches='tight')
                plt.close()

            logger.info(f"Visualisations saved to {plots_dir}")

        except Exception as e:
            logger.warning(f"Error creating visualisations: {e}")
    def _log_analysis_results(
        self,
        corr_matrix: pd.DataFrame,
        high_correlations: List[Dict[str, Any]],
        target_correlations: List[Dict[str, Any]],
        vif_results: Dict[str, Any]
    ) -> None:
        """Log analysis results"""
        logger.info("\n" + "="*80)
        logger.info("CORRELATION AND MULTICOLLINEARITY ANALYSIS REPORT")
        logger.info("="*80)

        # General information
        logger.info(f"\n📊 GENERAL INFORMATION:")
        logger.info(f"   Correlation matrix size: {corr_matrix.shape}")
        logger.info(f"   Total features: {len(corr_matrix.columns)}")

        # High correlations
        if high_correlations:
            logger.info(f"\n⚠ HIGH CORRELATIONS (|r| > 0.8): {len(high_correlations)}")
            logger.info("   " + "-" * 60)

            for i, corr in enumerate(high_correlations[:10]):
                sign = "🟥" if corr['correlation'] < 0 else "🟩"
                logger.info(f"   {i+1:2d}. {sign} {corr['feature1']:25s} ↔ {corr['feature2']:25s}: {corr['correlation']:7.4f}")

            if len(high_correlations) > 10:
                logger.info(f"   ... and {len(high_correlations) - 10} more pairs")

        # Target variable correlations
        if target_correlations:
            logger.info(f"\n🎯 CORRELATIONS WITH TARGET VARIABLE:")
            logger.info("   " + "-" * 60)

            for i, corr in enumerate(target_correlations[:10]):
                direction = "↓" if corr['correlation'] < 0 else "↑"
                logger.info(f"   {i+1:2d}. {direction} {corr['feature']:35s}: {corr['correlation']:7.4f}")

        # Multicollinearity analysis
        if vif_results['scores']:
            logger.info(f"\n📈 MULTICOLLINEARITY ANALYSIS (VIF):")
            logger.info("   " + "-" * 60)
            logger.info(f"   Critical (VIF > 100): {vif_results['summary']['critical']}")
            logger.info(f"   High (10 < VIF ≤ 100): {vif_results['summary']['high']}")
            logger.info(f"   Medium (5 < VIF ≤ 10): {vif_results['summary']['medium']}")
            logger.info(f"   Low (VIF ≤ 5): {vif_results['summary']['low']}")

            # Top problematic features
            if vif_results['issues']:
                logger.info(f"\n🔴 PROBLEMATIC FEATURES (VIF > 10):")
                for issue in vif_results['issues'][:10]:
                    logger.info(f"   • {issue['feature']:35s}: VIF = {issue['vif']:7.1f} ({issue['severity']})")

        logger.info("\n" + "="*80)
        logger.info("RECOMMENDATIONS:")
        logger.info("="*80)

        # Generate recommendations
        recommendations = []

        if len(high_correlations) > 20:
            recommendations.append("1. Remove highly correlated features (correlation method)")

        if vif_results['summary']['critical'] > 0:
            recommendations.append("2. Remove features with critical VIF (>100)")

        if vif_results['summary']['high'] > 5:
            recommendations.append("3. Consider removing features with VIF > 10")

        if not recommendations:
            recommendations.append("1. Data in good condition, no serious issues detected")
            recommendations.append("2. Proceed to modelling")

        for i, rec in enumerate(recommendations, 1):
            logger.info(f"   {rec}")

        logger.info("\n" + "="*80)
    def remove_highly_correlated(
        self,
        data: pd.DataFrame,
        threshold: float = 0.85,
        method: str = 'variance',
        keep_target: bool = True,
        keep_features: List[str] = None
    ) -> pd.DataFrame:
        """
        Remove highly correlated features

        Parameters:
        -----------
        data : pd.DataFrame
            Source data
        threshold : float
            Correlation threshold for removal
        method : str
            Feature selection method for removal: 'variance', 'random', 'importance'
        keep_target : bool
            Whether to keep target variable
        keep_features : List[str], optional
            Features to keep

        Returns:
        --------
        pd.DataFrame
            Data after removing highly correlated features
        """
        logger.info("\n" + "="*80)
        logger.info("REMOVING HIGHLY CORRELATED FEATURES")
        logger.info("="*80)

        data_clean = data.copy()

        if 'pearson' not in self.correlation_matrices:
            logger.warning("Correlation matrix not calculated, run analyze() first")
            return data_clean

        corr_matrix = self.correlation_matrices['pearson']

        # Features to keep
        features_to_keep = set()

        if keep_target and self.config.target_column in data_clean.columns:
            features_to_keep.add(self.config.target_column)

        if keep_features:
            for feat in keep_features:
                if feat in data_clean.columns:
                    features_to_keep.add(feat)

        # Temporal features (usually important for time series)
        temporal_patterns = ['year', 'month', 'day', 'week', 'quarter',
                             'hour', 'minute', 'second', 'sin', 'cos']

        for col in data_clean.columns:
            if any(pattern in col.lower() for pattern in temporal_patterns):
                features_to_keep.add(col)

        # Find highly correlated pairs
        upper_triangle = corr_matrix.where(
            np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
        )

        # Collect highly correlated features
        correlated_features = set()
        for col in upper_triangle.columns:
            if col in features_to_keep:
                continue

            high_corr = upper_triangle[col][abs(upper_triangle[col]) > threshold]
            for row_idx, corr_value in high_corr.items():
                if not pd.isna(corr_value) and row_idx not in features_to_keep:
                    # Select which feature to remove
                    if method == 'variance':
                        # Remove the one with lower variance
                        var_col = data_clean[col].var()
                        var_row = data_clean[row_idx].var()
                        feature_to_remove = col if var_col < var_row else row_idx
                    elif method == 'importance':
                        # Remove the one with lower correlation to target variable
                        if self.config.target_column in corr_matrix.columns:
                            corr_col_target = abs(corr_matrix.loc[col, self.config.target_column])
                            corr_row_target = abs(corr_matrix.loc[row_idx, self.config.target_column])
                            feature_to_remove = col if corr_col_target < corr_row_target else row_idx
                        else:
                            # If no target, remove randomly
                            feature_to_remove = np.random.choice([col, row_idx])
                    else:
                        # Remove randomly
                        feature_to_remove = np.random.choice([col, row_idx])

                    correlated_features.add(feature_to_remove)

        # Remove features
        features_to_remove = list(correlated_features)

        if features_to_remove:
            data_clean = data_clean.drop(columns=features_to_remove)

        logger.info(f"\n📊 REMOVAL RESULTS:")
        logger.info(f"   Initial feature count: {len(data.columns)}")
        logger.info(f"   Features removed: {len(features_to_remove)}")
        logger.info(f"   Final feature count: {len(data_clean.columns)}")
        logger.info(f"   Retained: {len(data_clean.columns)/len(data.columns)*100:.1f}%")

        if features_to_remove:
            logger.info(f"\n🗑️ REMOVED FEATURES:")
            for i, feat in enumerate(sorted(features_to_remove)[:20]):
                logger.info(f"   {i+1:2d}. {feat}")
            if len(features_to_remove) > 20:
                logger.info(f"   ... and {len(features_to_remove) - 20} more features")
        else:
            logger.info("✓ No highly correlated features detected, all features retained")

        logger.info("="*80)
        return data_clean
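Note: a standalone sketch of the same upper-triangle pruning idea on a made-up frame (it does not call this class); the 0.85 threshold mirrors the default above and the column names are assumptions.

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
df = pd.DataFrame(rng.normal(size=(300, 3)), columns=['x1', 'x2', 'x3'])
df['x1_copy'] = df['x1'] + rng.normal(scale=0.05, size=300)  # near-duplicate of x1

corr = df.corr().abs()
upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
to_drop = [c for c in upper.columns if (upper[c] > 0.85).any()]
print(to_drop)                          # expected: ['x1_copy']
print(df.drop(columns=to_drop).shape)   # (300, 3) after pruning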
    def get_report(self) -> Dict[str, Any]:
        """Get analysis report"""
        report = {
            "correlation_matrix_shape": None,
            "high_correlation_count": 0,
            "vif_summary": {},
            "target_correlation_count": 0
        }

        if 'pearson' in self.correlation_matrices:
            report["correlation_matrix_shape"] = self.correlation_matrices['pearson'].shape

        if 'pearson' in self.high_correlation_pairs:
            report["high_correlation_count"] = len(self.high_correlation_pairs['pearson'])

        if self.vif_scores:
            report["vif_summary"] = self.vif_scores.get('summary', {})

        return report
data_loader/__init__.py ADDED
File without changes

data_loader/data_loader.py ADDED
@@ -0,0 +1,487 @@
# ============================================
# CLASS 2: DATA LOADER
# ============================================
from datetime import datetime
import hashlib
import json
import logging
import traceback
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

from config.config import Config, DataType

logger = logging.getLogger(__name__)


class DataLoader:
    """Class for loading and initial data processing"""

    def __init__(self, config: Config):
        """
        Initialise data loader

        Parameters:
        -----------
        config : Config
            Experiment configuration
        """
        self.config = config
        self.data = None
        self.metadata = {}
        self.data_hash = None
        self.loading_time = None
        self.data_types = {}
        self.original_shape = None
    def load_from_csv(
        self,
        data_path: Optional[str] = None,
        parse_dates: List[str] = None,
        date_format: str = None,
        dtype: Dict = None,
        **kwargs
    ) -> pd.DataFrame:
        """
        Load data from CSV file

        Parameters:
        -----------
        data_path : str, optional
            Path to CSV file. If None, uses path from configuration.
        parse_dates : List[str], optional
            List of columns to parse as dates
        date_format : str, optional
            Date format
        dtype : Dict, optional
            Data types for columns
        **kwargs : dict
            Additional parameters for pd.read_csv

        Returns:
        --------
        pd.DataFrame
            Loaded data
        """
        logger.info("="*80)
        logger.info("LOADING DATA FROM CSV")
        logger.info("="*80)

        start_time = datetime.now()

        try:
            path = data_path or self.config.data_path

            if parse_dates is None:
                parse_dates = ['date']

            # Load data
            self.data = pd.read_csv(
                path,
                parse_dates=parse_dates,
                dayfirst=False,
                dtype=dtype,
                **kwargs
            )

            # Convert dates if needed
            for date_col in parse_dates:
                if date_col in self.data.columns:
                    if date_format:
                        self.data[date_col] = pd.to_datetime(
                            self.data[date_col],
                            format=date_format,
                            errors='coerce'
                        )
                    else:
                        self.data[date_col] = pd.to_datetime(
                            self.data[date_col],
                            errors='coerce'
                        )

            # Save original shape
            self.original_shape = self.data.shape

            # Filter by years
            if 'date' in self.data.columns:
                mask = (self.data['date'].dt.year >= self.config.start_year) & \
                       (self.data['date'].dt.year <= self.config.end_year)
                self.data = self.data.loc[mask].copy()

            # Sort by date
            if 'date' in self.data.columns:
                self.data = self.data.sort_values('date').reset_index(drop=True)
                # Set date as index
                self.data.set_index('date', inplace=True)

            # Calculate data hash
            self.data_hash = self._calculate_data_hash()

            # Analyse data types
            self._analyse_data_types()

            # Save metadata
            self._save_metadata()

            # Loading time
            self.loading_time = (datetime.now() - start_time).total_seconds()

            logger.info(f"✓ Loaded {len(self.data)} records, {len(self.data.columns)} columns")
            logger.info(f"   Period: {self.data.index.min()} - {self.data.index.max()}")
            logger.info(f"   Data types: {self.data_types}")
            logger.info(f"   Target variable: {self.config.target_column}")
            logger.info(f"   Loading time: {self.loading_time:.2f} sec")

            return self.data

        except Exception as e:
            logger.error(f"✗ Error loading data: {e}")
            logger.error(traceback.format_exc())
            raise
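Note: a hedged sketch of the same load-filter-index pattern outside the class; 'example.csv', the 'date' column name and the year bounds are assumptions, not values from this repository.

import pandas as pd

df = pd.read_csv('example.csv', parse_dates=['date'])
mask = (df['date'].dt.year >= 2000) & (df['date'].dt.year <= 2020)
df = df.loc[mask].sort_values('date').set_index('date')
print(df.index.min(), df.index.max())  # period actually kept after filtering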
    def create_synthetic_data(
        self,
        n_days: int = 365*21,
        trend_strength: float = 0.01,
        seasonal_amplitude: List[float] = None,
        noise_std: float = 10,
        include_exogenous: bool = True,
        random_state: int = 42
    ) -> pd.DataFrame:
        """
        Create synthetic data for testing

        Parameters:
        -----------
        n_days : int
            Number of days to generate
        trend_strength : float
            Trend strength
        seasonal_amplitude : List[float], optional
            Seasonal component amplitudes
        noise_std : float
            Noise standard deviation
        include_exogenous : bool
            Whether to include exogenous variables
        random_state : int
            Seed for reproducibility

        Returns:
        --------
        pd.DataFrame
            Synthetic data
        """
        logger.info("="*80)
        logger.info("CREATING SYNTHETIC DATA")
        logger.info("="*80)

        if seasonal_amplitude is None:
            seasonal_amplitude = [50, 30, 20]

        np.random.seed(random_state)

        # Generate dates
        dates = pd.date_range(
            start=f'{self.config.start_year}-01-01',
            periods=n_days,
            freq='D'
        )

        t = np.arange(n_days)

        # Base components
        trend = trend_strength * t

        # Seasonal components
        seasonal = 0
        periods = [365, 30, 7]  # yearly, monthly, weekly seasonality
        for i, (period, amplitude) in enumerate(zip(periods, seasonal_amplitude)):
            seasonal += amplitude * np.sin(2 * np.pi * t / period)
            if i < len(seasonal_amplitude) - 1:
                seasonal += 0.5 * amplitude * np.cos(4 * np.pi * t / period)

        # Cyclical component (business cycles)
        cycle = 20 * np.sin(2 * np.pi * t / (365*5))  # 5-year cycle

        # Noise
        noise = np.random.normal(0, noise_std, n_days)

        # Generate target variable
        raskhodvoda = 100 + trend + seasonal + cycle + noise

        # Create DataFrame
        self.data = pd.DataFrame(
            index=dates,
            data={'raskhodvoda': raskhodvoda}
        )

        # Generate exogenous variables
        if include_exogenous:
            # Temperature with seasonality
            tavg = 10 + 8 * np.sin(2 * np.pi * t / 365) + np.random.normal(0, 3, n_days)
            tmin = tavg - 5 + np.random.normal(0, 2, n_days)
            tmax = tavg + 5 + np.random.normal(0, 2, n_days)

            # Water level with trend and seasonality
            urovenvoda = 200 + 0.5 * t + 20 * np.sin(2 * np.pi * t / 365) + np.random.normal(0, 5, n_days)

            # Add to DataFrame
            self.data['tavg'] = tavg
            self.data['tmin'] = tmin
            self.data['tmax'] = tmax
            self.data['urovenvoda'] = urovenvoda

            # Add noisy lags
            for lag in [1, 7, 30]:
                self.data[f'tavg_lag_{lag}'] = self.data['tavg'].shift(lag) + np.random.normal(0, 1, n_days)

        # Add missing values and outliers for testing
        if n_days > 100:
            # Missing values (5% of data)
            mask_missing = np.random.random(n_days) < 0.05
            self.data.loc[mask_missing, 'tavg'] = np.nan

            # Outliers (1% of data)
            mask_outliers = np.random.random(n_days) < 0.01
            self.data.loc[mask_outliers, 'raskhodvoda'] *= 2

        # Save metadata
        self.metadata.update({
            'is_synthetic': True,
            'synthetic_params': {
                'n_days': n_days,
                'trend_strength': trend_strength,
                'seasonal_amplitude': seasonal_amplitude,
                'noise_std': noise_std,
                'include_exogenous': include_exogenous,
                'random_state': random_state
            }
        })

        logger.info(f"✓ Created {len(self.data)} synthetic records")
        logger.info(f"   Columns: {list(self.data.columns)}")

        return self.data
    def _calculate_data_hash(self) -> str:
        """Calculate data hash for tracking changes"""
        if self.data is None:
            return None

        # Use hash of first 1000 rows and metadata
        sample = self.data.head(1000).to_string().encode()
        return hashlib.md5(sample).hexdigest()

    def _analyse_data_types(self) -> None:
        """Analyse data types in DataFrame"""
        if self.data is None:
            return

        for col in self.data.columns:
            dtype = str(self.data[col].dtype)

            if 'datetime' in dtype:
                self.data_types[col] = DataType.TEMPORAL.value
            elif 'int' in dtype or 'float' in dtype:
                self.data_types[col] = DataType.NUMERIC.value
            elif 'object' in dtype or 'category' in dtype:
                # Check if categorical
                unique_ratio = self.data[col].nunique() / len(self.data)
                if unique_ratio < 0.1:  # Less than 10% unique values
                    self.data_types[col] = DataType.CATEGORICAL.value
                else:
                    self.data_types[col] = DataType.TEXT.value
            else:
                self.data_types[col] = 'unknown'

    def _save_metadata(self) -> None:
        """Save data metadata"""
        if self.data is None:
            return

        # Basic metadata
        self.metadata.update({
            'original_shape': list(self.original_shape) if self.original_shape else [],
            'current_shape': list(self.data.shape),
            'columns': list(self.data.columns),
            'data_types': self.data_types,
            'date_range': {
                'min': self.data.index.min().strftime('%Y-%m-%d') if pd.notnull(self.data.index.min()) else None,
                'max': self.data.index.max().strftime('%Y-%m-%d') if pd.notnull(self.data.index.max()) else None
            },
            'data_hash': self.data_hash,
            'loading_time': self.loading_time
        })

        # Statistics for numeric columns
        numeric_cols = self.data.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) > 0:
            stats = self.data[numeric_cols].describe().to_dict()
            # Add additional statistics
            for col in numeric_cols:
                stats[col]['skewness'] = float(self.data[col].skew())
                stats[col]['kurtosis'] = float(self.data[col].kurtosis())
                stats[col]['cv'] = float(self.data[col].std() / self.data[col].mean()) if self.data[col].mean() != 0 else np.nan

            self.metadata['numeric_statistics'] = stats

        # Missing values information
        missing_info = {
            'total_missing': int(self.data.isnull().sum().sum()),
            'missing_by_column': self.data.isnull().sum().to_dict(),
            'missing_percentage': (self.data.isnull().sum() / len(self.data) * 100).to_dict(),
            'rows_with_missing': int(self.data.isnull().any(axis=1).sum()),
            'columns_with_missing': self.data.columns[self.data.isnull().any()].tolist()
        }
        self.metadata['missing_info'] = missing_info

    def get_data_info(self) -> Dict:
        """Get information about data"""
        if self.data is None:
            return {}

        info = {
            'shape': list(self.data.shape),
            'columns': list(self.data.columns),
            'data_types': self.data_types,
            'date_range': {
                'min': self.data.index.min().strftime('%Y-%m-%d') if pd.notnull(self.data.index.min()) else None,
                'max': self.data.index.max().strftime('%Y-%m-%d') if pd.notnull(self.data.index.max()) else None
            },
            'target_column': self.config.target_column,
            'numeric_columns': self.data.select_dtypes(include=[np.number]).columns.tolist(),
            'categorical_columns': [col for col, dtype in self.data_types.items()
                                    if dtype == DataType.CATEGORICAL.value],
            'missing_info': self.metadata.get('missing_info', {})
        }

        return info

    def save_raw_data_info(self) -> None:
        """Save raw data information"""
        if self.data is None:
            return

        info_path = f'{self.config.results_dir}/reports/raw_data_info.json'

        # Custom JSON encoder for handling numpy types
        class NumpyEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, (np.integer, np.floating)):
                    if np.isnan(obj):
                        return None
                    return float(obj)
                elif isinstance(obj, np.bool_):
                    return bool(obj)
                elif isinstance(obj, np.ndarray):
                    return obj.tolist()
                elif isinstance(obj, pd.Timestamp):
                    return obj.strftime('%Y-%m-%d %H:%M:%S')
                elif isinstance(obj, pd.Period):
                    return str(obj)
                return super().default(obj)

        with open(info_path, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, indent=4, ensure_ascii=False, cls=NumpyEncoder)

        logger.info(f"✓ Raw data information saved: {info_path}")
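Note: a small sketch of why a custom encoder like the one above is needed; plain json.dumps cannot serialise numpy integers or arrays. The payload values here are made up, and the class below is a trimmed stand-in, not the project's encoder.

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        # Convert numpy scalars and arrays to native Python types
        if isinstance(obj, (np.integer, np.floating)):
            return None if np.isnan(obj) else float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

payload = {'rows': np.int64(487), 'counts': np.array([1, 2, 3])}
print(json.dumps(payload, cls=NumpyEncoder))  # {"rows": 487.0, "counts": [1, 2, 3]}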
    def resample_data(
        self,
        freq: str = None,
        method: str = 'mean'
    ) -> pd.DataFrame:
        """
        Resample time series data

        Parameters:
        -----------
        freq : str, optional
            New frequency (e.g., 'D', 'W', 'M')
        method : str
            Aggregation method: 'mean', 'sum', 'last', 'first'

        Returns:
        --------
        pd.DataFrame
            Resampled data
        """
        if self.data is None:
            logger.warning("Data not loaded")
            return None

        freq = freq or self.config.freq

        # Check if index is datetime
        if not isinstance(self.data.index, pd.DatetimeIndex):
            logger.error("Data index is not DatetimeIndex")
            return self.data

        # Aggregation methods
        agg_methods = {
            'mean': np.mean,
            'sum': np.sum,
            'last': lambda x: x.iloc[-1],
            'first': lambda x: x.iloc[0],
            'min': np.min,
            'max': np.max,
            'median': np.median
        }

        if method not in agg_methods:
            logger.warning(f"Method {method} not supported, using mean")
            method = 'mean'

        # Resampling
        try:
            if method == 'last':
                resampled_data = self.data.resample(freq).last()
            elif method == 'first':
                resampled_data = self.data.resample(freq).first()
            else:
                resampled_data = self.data.resample(freq).agg(agg_methods[method])

            logger.info(f"Data resampled to frequency {freq}, method {method}")
            logger.info(f"Size before: {len(self.data)}, after: {len(resampled_data)}")

            self.data = resampled_data
            return self.data

        except Exception as e:
            logger.error(f"Error during resampling: {e}")
            return self.data
    def detect_frequency(self) -> str:
        """
        Automatically detect data frequency

        Returns:
        --------
        str
            Detected data frequency
        """
        if self.data is None or len(self.data) < 2:
            return 'unknown'

        if not isinstance(self.data.index, pd.DatetimeIndex):
            return 'irregular'

        # Calculate differences between timestamps
        diffs = pd.Series(self.data.index).diff().dropna()

        if len(diffs) == 0:
            return 'unknown'

        # Most frequent difference
        mode_diff = diffs.mode().iloc[0] if not diffs.mode().empty else diffs.iloc[0]

        # Determine frequency (inclusive bounds so exact 1-hour/1-day/7-day steps
        # map to their own bucket instead of the next coarser one)
        if mode_diff <= pd.Timedelta('1 hour'):
            return 'H'  # Hourly
        elif mode_diff <= pd.Timedelta('1 day'):
            return 'D'  # Daily
        elif mode_diff <= pd.Timedelta('7 days'):
            return 'W'  # Weekly
        elif mode_diff <= pd.Timedelta('30 days'):
            return 'M'  # Monthly
        elif mode_diff <= pd.Timedelta('90 days'):
            return 'Q'  # Quarterly
        else:
            return 'Y'  # Yearly
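Note: a quick cross-check sketch; pandas' own infer_freq can be compared against the mode-of-differences heuristic above. The daily index here is made up.

import pandas as pd

idx = pd.date_range('2020-01-01', periods=30, freq='D')
print(pd.infer_freq(idx))                              # 'D'
print(pd.Series(idx).diff().dropna().mode().iloc[0])   # Timedelta('1 days')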
decomposition/__init__.py ADDED
File without changes

decomposition/decomposer.py ADDED
@@ -0,0 +1,690 @@
# ============================================
# CLASS 7: TIME SERIES DECOMPOSITION
# ============================================
import logging
import traceback
from typing import Dict, Optional

from config.config import Config

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy import stats
from statsmodels.tsa.seasonal import seasonal_decompose, STL
from statsmodels.tsa.stattools import adfuller, kpss, acf, pacf
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.holtwinters import ExponentialSmoothing

logger = logging.getLogger(__name__)


class TimeSeriesDecomposer:
    """Class for time series decomposition"""

    def __init__(self, config: Config):
        """
        Initialise decomposer

        Parameters:
        -----------
        config : Config
            Experiment configuration
        """
        self.config = config
        self.decomposition_results = {}
        self.decomposition_models = {}
        self.seasonal_periods = {}
    def decompose(
        self,
        data: pd.DataFrame,
        target_col: Optional[str] = None,
        method: str = 'stl',
        period: Optional[int] = None,
        **kwargs
    ) -> Dict:
        """
        Decompose time series into components

        Parameters:
        -----------
        data : pd.DataFrame
            Input data
        target_col : str, optional
            Target variable. If None, uses configuration value.
        method : str
            Decomposition model: 'stl', 'seasonal_decompose', 'mstl', 'naive'
        period : int, optional
            Seasonality period. If None, uses configuration value.
        **kwargs : dict
            Additional parameters for method

        Returns:
        --------
        Dict
            Decomposition results
        """
        logger.info("\n" + "="*80)
        logger.info("TIME SERIES DECOMPOSITION")
        logger.info("="*80)

        target_col = target_col or self.config.target_column
        period = period or self.config.seasonal_period

        if target_col not in data.columns:
            logger.error(f"Target variable '{target_col}' not found")
            return {}

        # Set date as index if not set
        if not isinstance(data.index, pd.DatetimeIndex):
            if 'date' in data.columns:
                data = data.set_index('date')
            else:
                logger.error("DatetimeIndex required for decomposition")
                return {}

        series = data[target_col]

        # Automatic seasonality period detection
        if period is None or period == 'auto':
            period = self._detect_seasonal_period(series)
            logger.info(f"Automatically detected seasonality period: {period}")

        try:
            decomposition_result = None

            if method == 'stl':
                decomposition_result = self._stl_decomposition(series, period, **kwargs)
            elif method == 'seasonal_decompose':
                decomposition_result = self._seasonal_decompose(series, period, **kwargs)
            elif method == 'mstl':
                decomposition_result = self._mstl_decomposition(series, **kwargs)
            elif method == 'naive':
                decomposition_result = self._naive_decomposition(series, period, **kwargs)
            else:
                logger.warning(f"Method {method} not supported, using STL")
                decomposition_result = self._stl_decomposition(series, period, **kwargs)

            if decomposition_result is None:
                logger.error("Decomposition failed")
                return {}

            # Analyse residuals
            residuals_info = self._analyse_residuals(decomposition_result.get('residual', None))

            # Analyse seasonality
            seasonal_info = self._analyse_seasonality(
                decomposition_result.get('seasonal', None),
                period
            )

            # Save results
            self.decomposition_results[target_col] = {
                'method': method,
                'period': period,
                'residuals_analysis': residuals_info,
                'seasonality_analysis': seasonal_info,
                'components_present': list(decomposition_result.keys()),
                'decomposition_stats': {
                    'trend_strength': self._calculate_trend_strength(
                        decomposition_result.get('trend', None),
                        decomposition_result.get('residual', None)
                    ),
                    'seasonal_strength': self._calculate_seasonal_strength(
                        decomposition_result.get('seasonal', None),
                        decomposition_result.get('residual', None)
                    )
                }
            }

            # Visualisation
            if self.config.save_plots:
                self._plot_decomposition(data, target_col, decomposition_result, method, period)

                # Additional visualisation
                if residuals_info:
                    self._plot_residuals_analysis(decomposition_result.get('residual', None), target_col)

            return self.decomposition_results[target_col]

        except Exception as e:
            logger.error(f"Error during decomposition: {e}")
            logger.error(traceback.format_exc())
            return {}
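Note: a self-contained sketch of the STL path used above, run directly on a made-up weekly-seasonal series with statsmodels; it does not call this class, and the period and data are assumptions.

import numpy as np
import pandas as pd
from statsmodels.tsa.seasonal import STL

idx = pd.date_range('2022-01-01', periods=200, freq='D')
t = np.arange(200)
y = pd.Series(
    0.05 * t + 5 * np.sin(2 * np.pi * t / 7) + np.random.default_rng(0).normal(size=200),
    index=idx,
)

res = STL(y, period=7, robust=True).fit()  # same call shape as _stl_decomposition above
print(res.trend.iloc[:3])
print(res.seasonal.iloc[:3])
print(res.resid.iloc[:3])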
    def _detect_seasonal_period(self, series: pd.Series) -> int:
        """Automatic seasonality period detection"""
        if len(series) < 100:
            return self.config.seasonal_period

        try:
            # Use autocorrelation to determine period
            acf_values = acf(series.dropna(), nlags=min(500, len(series)//2))

            # Find peaks in autocorrelation
            peaks = []
            for i in range(1, len(acf_values)-1):
                if acf_values[i] > acf_values[i-1] and acf_values[i] > acf_values[i+1]:
                    if acf_values[i] > 0.3:  # Significance threshold
                        peaks.append(i)

            if peaks:
                # Take most significant period
                dominant_period = peaks[0]

                # Check for multiple periods
                for period in [7, 30, 90, 365]:
                    if abs(dominant_period - period) <= 2:
                        return period

                return dominant_period

            return self.config.seasonal_period

        except:
            return self.config.seasonal_period
    def _stl_decomposition(
        self,
        series: pd.Series,
        period: int,
        **kwargs
    ) -> Optional[Dict]:
        """STL decomposition"""
        try:
            if len(series) < 2 * period:
                logger.warning(f"Insufficient data for STL decomposition with period {period}")
                return self._seasonal_decompose(series, period, **kwargs)

            # STL decomposition
            stl = STL(
                series,
                period=period,
                seasonal=kwargs.get('seasonal', 7),
                trend=kwargs.get('trend', None),
                robust=kwargs.get('robust', True),
                seasonal_deg=kwargs.get('seasonal_deg', 1),
                trend_deg=kwargs.get('trend_deg', 1),
                low_pass_deg=kwargs.get('low_pass_deg', 1)
            )

            result = stl.fit()

            return {
                'trend': result.trend,
                'seasonal': result.seasonal,
                'residual': result.resid,
                'observed': series
            }

        except Exception as e:
            logger.warning(f"STL decomposition failed: {e}")
            return self._seasonal_decompose(series, period, **kwargs)

    def _seasonal_decompose(
        self,
        series: pd.Series,
        period: int,
        **kwargs
    ) -> Optional[Dict]:
        """Seasonal decomposition from statsmodels"""
        try:
            model = kwargs.get('model', 'additive')

            if len(series) < 2 * period:
                # Reduce period if insufficient data
                period = max(7, len(series) // 4)

            decomposition = seasonal_decompose(
                series,
                model=model,
                period=period,
                extrapolate_trend=kwargs.get('extrapolate_trend', 'freq'),
                two_sided=kwargs.get('two_sided', True)
            )

            return {
                'trend': decomposition.trend,
                'seasonal': decomposition.seasonal,
                'residual': decomposition.resid,
                'observed': series
            }

        except Exception as e:
            logger.warning(f"Seasonal decompose failed: {e}")
            return self._naive_decomposition(series, period, **kwargs)

    def _mstl_decomposition(
        self,
        series: pd.Series,
        **kwargs
    ) -> Optional[Dict]:
        """Multi-seasonal decomposition (simplified)"""
        try:
            # Simplified MSTL version
            periods = kwargs.get('periods', [7, 365])

            result = {
                'observed': series,
                'trend': None,
                'seasonal': pd.Series(0, index=series.index),
                'residual': series.copy()
            }

            # Sequentially remove seasonal components
            for period in periods:
                if len(series) >= 2 * period:
                    try:
                        decomp = seasonal_decompose(
                            result['residual'],
                            model='additive',
                            period=period,
                            extrapolate_trend='freq'
                        )

                        if result['trend'] is None:
                            result['trend'] = decomp.trend

                        result['seasonal'] = result['seasonal'] + decomp.seasonal
                        result['residual'] = decomp.resid
                    except:
                        continue

            if result['trend'] is None:
                result['trend'] = series.rolling(window=min(365, len(series)//4), center=True).mean()

            return result

        except Exception as e:
            logger.warning(f"MSTL decomposition failed: {e}")
            return self._seasonal_decompose(series, 365, **kwargs)

    def _naive_decomposition(
        self,
        series: pd.Series,
        period: int,
        **kwargs
    ) -> Optional[Dict]:
        """Naive decomposition"""
        try:
            # Simple decomposition using moving averages
            trend = series.rolling(
                window=min(period, len(series)//4),
                center=True,
                min_periods=1
            ).mean()

            # Seasonal component
            if period > 1:
                # Average by seasons
                seasonal = series.groupby(series.index.dayofyear if period == 365 else
                                          series.index.dayofweek if period == 7 else
                                          series.index.month).transform('mean')
                seasonal = seasonal - seasonal.mean()
            else:
                seasonal = pd.Series(0, index=series.index)

            residual = series - trend - seasonal

            return {
                'trend': trend,
                'seasonal': seasonal,
                'residual': residual,
                'observed': series
            }

        except Exception as e:
            logger.error(f"Naive decomposition failed: {e}")
            return None
    def _analyse_residuals(self, residuals) -> Dict:
        """Analyse decomposition residuals"""
        if residuals is None:
            return {}

        residuals_clean = residuals.dropna()

        if len(residuals_clean) == 0:
            return {}

        stats_info = {
            'mean': float(residuals_clean.mean()),
            'std': float(residuals_clean.std()),
            'skewness': float(residuals_clean.skew()),
            'kurtosis': float(residuals_clean.kurtosis()),
            'min': float(residuals_clean.min()),
            'max': float(residuals_clean.max()),
            'mad': float((residuals_clean - residuals_clean.mean()).abs().mean()),
            'normality_tests': {},
            'autocorrelation_tests': {}
        }

        # Normality test
        if len(residuals_clean) > 3:
            try:
                # Shapiro-Wilk test
                shapiro_stat, shapiro_p = stats.shapiro(residuals_clean.iloc[:5000])
                stats_info['normality_tests']['shapiro_wilk'] = {
                    'statistic': float(shapiro_stat),
                    'pvalue': float(shapiro_p),
                    'is_normal': shapiro_p > 0.05
                }

                # Anderson-Darling test
                anderson_result = stats.anderson(residuals_clean, dist='norm')
                stats_info['normality_tests']['anderson_darling'] = {
                    'statistic': float(anderson_result.statistic),
                    'critical_values': {str(level): float(value)
                                        for level, value in zip(anderson_result.significance_level,
                                                                anderson_result.critical_values)},
                    'is_normal': anderson_result.statistic < anderson_result.critical_values[2]  # At 5% level
                }
            except:
                stats_info['normality_tests']['error'] = 'not enough data or calculation error'

        # Autocorrelation test
        try:
            # Ljung-Box test
            lb_test = acorr_ljungbox(residuals_clean, lags=[10, 20, 30], return_df=True)

            autocorr_info = {}
            for idx, row in lb_test.iterrows():
                autocorr_info[f'lag_{int(row.name)}'] = {
                    'statistic': float(row['lb_stat']),
                    'pvalue': float(row['lb_pvalue']),
                    'has_autocorrelation': row['lb_pvalue'] < 0.05
                }

            stats_info['autocorrelation_tests']['ljung_box'] = autocorr_info

            # Durbin-Watson test
            try:
                dw_stat = sm.stats.stattools.durbin_watson(residuals_clean)
                stats_info['autocorrelation_tests']['durbin_watson'] = {
                    'statistic': float(dw_stat),
                    'interpretation': 'no autocorrelation' if 1.5 < dw_stat < 2.5 else
                                      'positive autocorrelation' if dw_stat < 1.5 else
                                      'negative autocorrelation'
                }
            except:
                pass

        except:
            stats_info['autocorrelation_tests']['error'] = 'calculation error'

        # Heteroskedasticity test
        try:
            # ARCH test
            from statsmodels.stats.diagnostic import het_arch
            arch_test = het_arch(residuals_clean)
            stats_info['heteroskedasticity_tests'] = {
                'arch': {
                    'statistic': float(arch_test[0]),
                    'pvalue': float(arch_test[1]),
                    'is_homoskedastic': arch_test[1] > 0.05
                }
            }
        except:
            pass

        return stats_info
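Note: a minimal sketch of the Ljung-Box check applied above; for pure white noise the p-values should stay well above 0.05 (no residual autocorrelation). The noise series is made up.

import numpy as np
from statsmodels.stats.diagnostic import acorr_ljungbox

noise = np.random.default_rng(0).normal(size=500)
print(acorr_ljungbox(noise, lags=[10, 20, 30], return_df=True))  # columns lb_stat, lb_pvalue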
def _analyse_seasonality(self, seasonal_component, period: int) -> Dict:
|
| 435 |
+
"""Analyse seasonal component"""
|
| 436 |
+
if seasonal_component is None:
|
| 437 |
+
return {}
|
| 438 |
+
|
| 439 |
+
seasonal_clean = seasonal_component.dropna()
|
| 440 |
+
|
| 441 |
+
if len(seasonal_clean) == 0:
|
| 442 |
+
return {}
|
| 443 |
+
|
| 444 |
+
analysis = {
|
| 445 |
+
'period': period,
|
| 446 |
+
'amplitude': float(seasonal_clean.max() - seasonal_clean.min()),
|
| 447 |
+
'mean_amplitude': float(seasonal_clean.abs().mean()),
|
| 448 |
+
'seasonal_strength': float(seasonal_clean.std()),
|
| 449 |
+
'periodicity_check': {}
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
# Check periodicity via autocorrelation
|
| 453 |
+
if len(seasonal_clean) > period * 2:
|
| 454 |
+
try:
|
| 455 |
+
acf_values = acf(seasonal_clean, nlags=min(period * 3, len(seasonal_clean)//2))
|
| 456 |
+
|
| 457 |
+
# Look for peaks at expected lags
|
| 458 |
+
expected_lags = [period, period*2]
|
| 459 |
+
peaks_found = []
|
| 460 |
+
|
| 461 |
+
for lag in expected_lags:
|
| 462 |
+
if lag < len(acf_values):
|
| 463 |
+
if acf_values[lag] > 0.5: # Strong autocorrelation at period
|
| 464 |
+
peaks_found.append({
|
| 465 |
+
'lag': lag,
|
| 466 |
+
'autocorrelation': float(acf_values[lag]),
|
| 467 |
+
'is_significant': True
|
| 468 |
+
})
|
| 469 |
+
|
| 470 |
+
analysis['periodicity_check']['autocorrelation_peaks'] = peaks_found
|
| 471 |
+
analysis['periodicity_check']['is_periodic'] = len(peaks_found) > 0
|
| 472 |
+
except:
|
| 473 |
+
pass
|
| 474 |
+
|
| 475 |
+
# Seasonality pattern analysis
|
| 476 |
+
if isinstance(seasonal_clean.index, pd.DatetimeIndex):
|
| 477 |
+
try:
|
| 478 |
+
# Group by months/week days
|
| 479 |
+
if period == 12 or period == 365:
|
| 480 |
+
# Monthly seasonality
|
| 481 |
+
monthly_seasonal = seasonal_clean.groupby(seasonal_clean.index.month).mean()
|
| 482 |
+
analysis['monthly_pattern'] = monthly_seasonal.to_dict()
|
| 483 |
+
|
| 484 |
+
if period == 7 or period == 365:
|
| 485 |
+
# Daily seasonality
|
| 486 |
+
daily_seasonal = seasonal_clean.groupby(seasonal_clean.index.dayofweek).mean()
|
| 487 |
+
analysis['daily_pattern'] = daily_seasonal.to_dict()
|
| 488 |
+
except Exception:
|
| 489 |
+
pass
|
| 490 |
+
|
| 491 |
+
return analysis
|
| 492 |
+
|
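The periodicity check above looks for strong autocorrelation at the expected seasonal lags. A minimal sketch of the same idea on a synthetic component with a known period of 12:

    # Hedged sketch: ACF-based periodicity check of a seasonal component.
    import numpy as np
    import pandas as pd
    from statsmodels.tsa.stattools import acf

    period = 12
    seasonal = pd.Series(np.sin(2 * np.pi * np.arange(240) / period))

    acf_values = acf(seasonal, nlags=min(period * 3, len(seasonal) // 2))
    peaks = [lag for lag in (period, period * 2) if acf_values[lag] > 0.5]
    print("periodic" if peaks else "no clear periodicity", peaks)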
| 493 |
+
def _calculate_trend_strength(self, trend, residual) -> float:
|
| 494 |
+
"""Calculate trend strength"""
|
| 495 |
+
if trend is None or residual is None:
|
| 496 |
+
return 0.0
|
| 497 |
+
|
| 498 |
+
trend_clean = trend.dropna()
|
| 499 |
+
residual_clean = residual.dropna()
|
| 500 |
+
|
| 501 |
+
if len(trend_clean) == 0 or len(residual_clean) == 0:
|
| 502 |
+
return 0.0
|
| 503 |
+
|
| 504 |
+
# Trend strength = 1 - Var(residual) / Var(trend + residual)
|
| 505 |
+
try:
|
| 506 |
+
var_total = np.var(trend_clean + residual_clean)
|
| 507 |
+
if var_total > 0:
|
| 508 |
+
trend_strength = 1 - np.var(residual_clean) / var_total
|
| 509 |
+
return max(0.0, min(1.0, float(trend_strength)))
|
| 510 |
+
except Exception:
|
| 511 |
+
pass
|
| 512 |
+
|
| 513 |
+
return 0.0
|
| 514 |
+
|
| 515 |
+
def _calculate_seasonal_strength(self, seasonal, residual) -> float:
|
| 516 |
+
"""Calculate seasonality strength"""
|
| 517 |
+
if seasonal is None or residual is None:
|
| 518 |
+
return 0.0
|
| 519 |
+
|
| 520 |
+
seasonal_clean = seasonal.dropna()
|
| 521 |
+
residual_clean = residual.dropna()
|
| 522 |
+
|
| 523 |
+
if len(seasonal_clean) == 0 or len(residual_clean) == 0:
|
| 524 |
+
return 0.0
|
| 525 |
+
|
| 526 |
+
# Seasonality strength = 1 - Var(residual) / Var(seasonal + residual)
|
| 527 |
+
try:
|
| 528 |
+
var_total = np.var(seasonal_clean + residual_clean)
|
| 529 |
+
if var_total > 0:
|
| 530 |
+
seasonal_strength = 1 - np.var(residual_clean) / var_total
|
| 531 |
+
return max(0.0, min(1.0, float(seasonal_strength)))
|
| 532 |
+
except Exception:
|
| 533 |
+
pass
|
| 534 |
+
|
| 535 |
+
return 0.0
|
| 536 |
+
|
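The two strength measures above follow the usual definition strength = max(0, 1 - Var(residual) / Var(component + residual)), where the component is the trend or the seasonal part. A minimal sketch computing both from a seasonal_decompose result on synthetic data (variable names are illustrative):

    # Hedged sketch: trend and seasonal strength from a classical decomposition.
    import numpy as np
    import pandas as pd
    from statsmodels.tsa.seasonal import seasonal_decompose

    idx = pd.date_range("2020-01-01", periods=730, freq="D")
    t = np.arange(730)
    y = pd.Series(0.01 * t + 5 * np.sin(2 * np.pi * t / 365)
                  + np.random.default_rng(1).normal(scale=0.5, size=730), index=idx)

    res = seasonal_decompose(y, model="additive", period=365, extrapolate_trend="freq")
    trend, seasonal, resid = res.trend.dropna(), res.seasonal.dropna(), res.resid.dropna()

    trend_strength = max(0.0, 1 - np.var(resid) / np.var(trend + resid))
    seasonal_strength = max(0.0, 1 - np.var(resid) / np.var(seasonal + resid))
    print(f"trend: {trend_strength:.2f}, seasonal: {seasonal_strength:.2f}")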
| 537 |
+
def _plot_decomposition(
|
| 538 |
+
self,
|
| 539 |
+
data: pd.DataFrame,
|
| 540 |
+
target_col: str,
|
| 541 |
+
decomposition: Dict,
|
| 542 |
+
method: str,
|
| 543 |
+
period: int
|
| 544 |
+
) -> None:
|
| 545 |
+
"""Visualise decomposition"""
|
| 546 |
+
fig, axes = plt.subplots(4, 1, figsize=(14, 12))
|
| 547 |
+
|
| 548 |
+
# Original series
|
| 549 |
+
axes[0].plot(decomposition.get('observed', pd.Series()))
|
| 550 |
+
axes[0].set_ylabel('Observed')
|
| 551 |
+
axes[0].set_title(f'Time Series Decomposition: {target_col} ({method}, period={period})')
|
| 552 |
+
axes[0].grid(True, alpha=0.3)
|
| 553 |
+
|
| 554 |
+
# Trend
|
| 555 |
+
if 'trend' in decomposition and decomposition['trend'] is not None:
|
| 556 |
+
axes[1].plot(decomposition['trend'])
|
| 557 |
+
axes[1].set_ylabel('Trend')
|
| 558 |
+
axes[1].grid(True, alpha=0.3)
|
| 559 |
+
|
| 560 |
+
# Seasonality
|
| 561 |
+
if 'seasonal' in decomposition and decomposition['seasonal'] is not None:
|
| 562 |
+
axes[2].plot(decomposition['seasonal'])
|
| 563 |
+
axes[2].set_ylabel('Seasonality')
|
| 564 |
+
axes[2].grid(True, alpha=0.3)
|
| 565 |
+
|
| 566 |
+
# Residuals
|
| 567 |
+
if 'residual' in decomposition and decomposition['residual'] is not None:
|
| 568 |
+
axes[3].plot(decomposition['residual'])
|
| 569 |
+
axes[3].set_ylabel('Residuals')
|
| 570 |
+
axes[3].set_xlabel('Date')
|
| 571 |
+
axes[3].grid(True, alpha=0.3)
|
| 572 |
+
|
| 573 |
+
plt.tight_layout()
|
| 574 |
+
plt.savefig(
|
| 575 |
+
f'{self.config.results_dir}/plots/decomposition_{target_col}.png',
|
| 576 |
+
dpi=300,
|
| 577 |
+
bbox_inches='tight'
|
| 578 |
+
)
|
| 579 |
+
plt.show()
|
| 580 |
+
|
| 581 |
+
# Additional plots
|
| 582 |
+
self._plot_decomposition_components(data, target_col, decomposition)
|
| 583 |
+
|
| 584 |
+
def _plot_decomposition_components(
|
| 585 |
+
self,
|
| 586 |
+
data: pd.DataFrame,
|
| 587 |
+
target_col: str,
|
| 588 |
+
decomposition: Dict
|
| 589 |
+
) -> None:
|
| 590 |
+
"""Visualise decomposition components"""
|
| 591 |
+
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
| 592 |
+
|
| 593 |
+
# 1. Sum of components vs original series
|
| 594 |
+
if all(k in decomposition for k in ['trend', 'seasonal', 'residual']):
|
| 595 |
+
reconstructed = decomposition['trend'] + decomposition['seasonal'] + decomposition['residual']
|
| 596 |
+
axes[0, 0].plot(decomposition['observed'], alpha=0.7, label='Original')
|
| 597 |
+
axes[0, 0].plot(reconstructed, alpha=0.7, label='Reconstructed')
|
| 598 |
+
axes[0, 0].set_title('Original vs Reconstructed Series')
|
| 599 |
+
axes[0, 0].set_xlabel('Date')
|
| 600 |
+
axes[0, 0].set_ylabel(target_col)
|
| 601 |
+
axes[0, 0].legend()
|
| 602 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 603 |
+
|
| 604 |
+
# 2. Residuals distribution
|
| 605 |
+
if 'residual' in decomposition and decomposition['residual'] is not None:
|
| 606 |
+
residuals = decomposition['residual'].dropna()
|
| 607 |
+
axes[0, 1].hist(residuals, bins=30, edgecolor='black', alpha=0.7, density=True)
|
| 608 |
+
|
| 609 |
+
# Normal distribution for comparison
|
| 610 |
+
xmin, xmax = axes[0, 1].get_xlim()
|
| 611 |
+
x = np.linspace(xmin, xmax, 100)
|
| 612 |
+
p = stats.norm.pdf(x, residuals.mean(), residuals.std())
|
| 613 |
+
axes[0, 1].plot(x, p, 'k', linewidth=2, label='Normal distribution')
|
| 614 |
+
|
| 615 |
+
axes[0, 1].set_title('Residuals Distribution')
|
| 616 |
+
axes[0, 1].set_xlabel('Residuals')
|
| 617 |
+
axes[0, 1].set_ylabel('Density')
|
| 618 |
+
axes[0, 1].legend()
|
| 619 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 620 |
+
|
| 621 |
+
# 3. ACF of residuals
|
| 622 |
+
if 'residual' in decomposition and decomposition['residual'] is not None:
|
| 623 |
+
plot_acf(decomposition['residual'].dropna(), lags=50, ax=axes[1, 0], alpha=0.05)
|
| 624 |
+
axes[1, 0].set_title('Residuals ACF')
|
| 625 |
+
axes[1, 0].set_xlabel('Lag')
|
| 626 |
+
axes[1, 0].set_ylabel('Autocorrelation')
|
| 627 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 628 |
+
|
| 629 |
+
# 4. Seasonal pattern
|
| 630 |
+
if 'seasonal' in decomposition and decomposition['seasonal'] is not None:
|
| 631 |
+
seasonal = decomposition['seasonal']
|
| 632 |
+
if isinstance(seasonal.index, pd.DatetimeIndex):
|
| 633 |
+
# Group by months
|
| 634 |
+
try:
|
| 635 |
+
monthly_seasonal = seasonal.groupby(seasonal.index.month).mean()
|
| 636 |
+
axes[1, 1].bar(monthly_seasonal.index, monthly_seasonal.values)
|
| 637 |
+
axes[1, 1].set_title('Average Seasonal Pattern by Month')
|
| 638 |
+
axes[1, 1].set_xlabel('Month')
|
| 639 |
+
axes[1, 1].set_ylabel('Seasonality')
|
| 640 |
+
axes[1, 1].set_xticks(range(1, 13))
|
| 641 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 642 |
+
except Exception:
|
| 643 |
+
axes[1, 1].plot(seasonal.index, seasonal.values)
|
| 644 |
+
axes[1, 1].set_title('Seasonal Component')
|
| 645 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 646 |
+
|
| 647 |
+
plt.tight_layout()
|
| 648 |
+
plt.savefig(
|
| 649 |
+
f'{self.config.results_dir}/plots/decomposition_components_{target_col}.png',
|
| 650 |
+
dpi=300,
|
| 651 |
+
bbox_inches='tight'
|
| 652 |
+
)
|
| 653 |
+
plt.show()
|
| 654 |
+
|
| 655 |
+
def _plot_residuals_analysis(self, residuals, target_col: str) -> None:
|
| 656 |
+
"""Visualise residual analysis"""
|
| 657 |
+
if residuals is None:
|
| 658 |
+
return
|
| 659 |
+
|
| 660 |
+
residuals_clean = residuals.dropna()
|
| 661 |
+
|
| 662 |
+
if len(residuals_clean) == 0:
|
| 663 |
+
return
|
| 664 |
+
|
| 665 |
+
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
|
| 666 |
+
|
| 667 |
+
# Q-Q plot
|
| 668 |
+
stats.probplot(residuals_clean, dist="norm", plot=axes[0])
|
| 669 |
+
axes[0].set_title('Residuals Q-Q plot')
|
| 670 |
+
axes[0].grid(True, alpha=0.3)
|
| 671 |
+
|
| 672 |
+
# Residuals over time
|
| 673 |
+
axes[1].plot(residuals_clean.index, residuals_clean.values, linewidth=0.5)
|
| 674 |
+
axes[1].axhline(y=0, color='r', linestyle='-', alpha=0.3)
|
| 675 |
+
axes[1].set_title('Residuals Over Time')
|
| 676 |
+
axes[1].set_xlabel('Date')
|
| 677 |
+
axes[1].set_ylabel('Residuals')
|
| 678 |
+
axes[1].grid(True, alpha=0.3)
|
| 679 |
+
|
| 680 |
+
plt.tight_layout()
|
| 681 |
+
plt.savefig(
|
| 682 |
+
f'{self.config.results_dir}/plots/residuals_analysis_{target_col}.png',
|
| 683 |
+
dpi=300,
|
| 684 |
+
bbox_inches='tight'
|
| 685 |
+
)
|
| 686 |
+
plt.show()
|
| 687 |
+
|
| 688 |
+
def get_report(self) -> Dict:
|
| 689 |
+
"""Get decomposition report"""
|
| 690 |
+
return self.decomposition_results
|
feature_selection/__init__.py
ADDED
|
File without changes
|
feature_selection/feature_selector.py
ADDED
|
@@ -0,0 +1,478 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 11: FEATURE SELECTION
|
| 3 |
+
# ============================================
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
import logging; logger = logging.getLogger(__name__)  # fix: logger was accidentally imported from venv
|
| 6 |
+
from config.config import Config
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import seaborn as sns
|
| 13 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 14 |
+
from sklearn.decomposition import PCA
|
| 15 |
+
from sklearn.preprocessing import StandardScaler
|
| 16 |
+
print("✅ All imports working!")
|
| 17 |
+
except ImportError as e:
|
| 18 |
+
print(f"❌ Import error: {e}")
|
| 19 |
+
|
| 20 |
+
from sklearn.inspection import permutation_importance, partial_dependence
|
| 21 |
+
from sklearn.feature_selection import (
|
| 22 |
+
SelectKBest, SelectPercentile, RFE, RFECV, VarianceThreshold,
|
| 23 |
+
f_regression, mutual_info_regression
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
class FeatureSelector:
|
| 27 |
+
"""Class for selecting the most important features"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, config: Config):
|
| 30 |
+
"""
|
| 31 |
+
Initialise feature selector
|
| 32 |
+
|
| 33 |
+
Parameters:
|
| 34 |
+
-----------
|
| 35 |
+
config : Config
|
| 36 |
+
Experiment configuration
|
| 37 |
+
"""
|
| 38 |
+
self.config = config
|
| 39 |
+
self.selected_features = []
|
| 40 |
+
self.feature_importances = {}
|
| 41 |
+
self.selection_methods = {}
|
| 42 |
+
self.selector_objects = {}
|
| 43 |
+
|
| 44 |
+
def select(
|
| 45 |
+
self,
|
| 46 |
+
data: pd.DataFrame,
|
| 47 |
+
target_col: Optional[str] = None,
|
| 48 |
+
method: str = None,
|
| 49 |
+
n_features: int = None,
|
| 50 |
+
**kwargs
|
| 51 |
+
) -> pd.DataFrame:
|
| 52 |
+
"""
|
| 53 |
+
Select the most important features
|
| 54 |
+
|
| 55 |
+
Parameters:
|
| 56 |
+
-----------
|
| 57 |
+
data : pd.DataFrame
|
| 58 |
+
Input data
|
| 59 |
+
target_col : str, optional
|
| 60 |
+
Target variable. If None, uses configuration value.
|
| 61 |
+
method : str, optional
|
| 62 |
+
Selection method. If None, uses configuration value.
|
| 63 |
+
n_features : int, optional
|
| 64 |
+
Number of features to select. If None, uses configuration value.
|
| 65 |
+
**kwargs : dict
|
| 66 |
+
Additional parameters for method
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
--------
|
| 70 |
+
pd.DataFrame
|
| 71 |
+
Data with selected features
|
| 72 |
+
"""
|
| 73 |
+
logger.info("\n" + "="*80)
|
| 74 |
+
logger.info("FEATURE SELECTION")
|
| 75 |
+
logger.info("="*80)
|
| 76 |
+
|
| 77 |
+
target_col = target_col or self.config.target_column
|
| 78 |
+
method = method or self.config.feature_selection_method
|
| 79 |
+
n_features = n_features or self.config.max_features
|
| 80 |
+
|
| 81 |
+
if target_col not in data.columns:
|
| 82 |
+
logger.error(f"Target variable '{target_col}' not found")
|
| 83 |
+
return data
|
| 84 |
+
|
| 85 |
+
# Prepare data
|
| 86 |
+
X = data.drop(columns=[target_col]).select_dtypes(include=[np.number])
|
| 87 |
+
y = data[target_col]
|
| 88 |
+
|
| 89 |
+
# Remove missing values
|
| 90 |
+
mask = X.notna().all(axis=1) & y.notna()
|
| 91 |
+
X_clean = X[mask]
|
| 92 |
+
y_clean = y[mask]
|
| 93 |
+
|
| 94 |
+
if len(X_clean) < 10 or len(X_clean.columns) < 2:
|
| 95 |
+
logger.warning("Insufficient data for feature selection")
|
| 96 |
+
return data
|
| 97 |
+
|
| 98 |
+
logger.info(f"Selection method: {method}")
|
| 99 |
+
logger.info(f"Target number of features: {n_features}")
|
| 100 |
+
logger.info(f"Initial number of features: {len(X.columns)}")
|
| 101 |
+
logger.info(f"Data for selection: {len(X_clean)} records")
|
| 102 |
+
|
| 103 |
+
# Apply selection method
|
| 104 |
+
selected_features_list = []
|
| 105 |
+
feature_importance_dict = {}
|
| 106 |
+
|
| 107 |
+
if method == 'correlation':
|
| 108 |
+
selected_features_list, feature_importance_dict = self._correlation_selection(
|
| 109 |
+
X_clean, y_clean, n_features, **kwargs
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
elif method == 'mutual_info':
|
| 113 |
+
selected_features_list, feature_importance_dict = self._mutual_info_selection(
|
| 114 |
+
X_clean, y_clean, n_features, **kwargs
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
elif method == 'rf':
|
| 118 |
+
selected_features_list, feature_importance_dict = self._random_forest_selection(
|
| 119 |
+
X_clean, y_clean, n_features, **kwargs
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
elif method == 'pca':
|
| 123 |
+
selected_features_list, feature_importance_dict = self._pca_selection(
|
| 124 |
+
X_clean, y_clean, n_features, **kwargs
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
elif method == 'rfe':
|
| 128 |
+
selected_features_list, feature_importance_dict = self._rfe_selection(
|
| 129 |
+
X_clean, y_clean, n_features, **kwargs
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
elif method == 'lasso':
|
| 133 |
+
selected_features_list, feature_importance_dict = self._lasso_selection(
|
| 134 |
+
X_clean, y_clean, n_features, **kwargs
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
elif method == 'hybrid':
|
| 138 |
+
selected_features_list, feature_importance_dict = self._hybrid_selection(
|
| 139 |
+
X_clean, y_clean, n_features, **kwargs
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
else:
|
| 143 |
+
logger.warning(f"Method {method} not supported, using correlation")
|
| 144 |
+
selected_features_list, feature_importance_dict = self._correlation_selection(
|
| 145 |
+
X_clean, y_clean, n_features, **kwargs
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Save selected features
|
| 149 |
+
self.selected_features = selected_features_list
|
| 150 |
+
self.feature_importances = feature_importance_dict
|
| 151 |
+
self.selection_methods[method] = {
|
| 152 |
+
'selected_features': selected_features_list,
|
| 153 |
+
'n_features': len(selected_features_list),
|
| 154 |
+
'feature_importances': feature_importance_dict
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
# Form final dataset
|
| 158 |
+
features_to_keep = selected_features_list + [target_col]
|
| 159 |
+
features_to_keep = [f for f in features_to_keep if f in data.columns]
|
| 160 |
+
|
| 161 |
+
data_selected = data[features_to_keep].copy()
|
| 162 |
+
|
| 163 |
+
logger.info(f"✓ Selected {len(selected_features_list)} features")
|
| 164 |
+
logger.info(f" Total features kept: {len(data_selected.columns)}")
|
| 165 |
+
|
| 166 |
+
# Visualisation
|
| 167 |
+
if self.config.save_plots and selected_features_list:
|
| 168 |
+
self._plot_feature_selection(
|
| 169 |
+
X_clean, y_clean, selected_features_list,
|
| 170 |
+
feature_importance_dict, method
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return data_selected
|
| 174 |
+
|
| 175 |
+
def _correlation_selection(
|
| 176 |
+
self,
|
| 177 |
+
X: pd.DataFrame,
|
| 178 |
+
y: pd.Series,
|
| 179 |
+
n_features: int,
|
| 180 |
+
**kwargs
|
| 181 |
+
) -> Tuple[List[str], Dict]:
|
| 182 |
+
"""Feature selection based on correlation"""
|
| 183 |
+
# Calculate correlations with target variable
|
| 184 |
+
correlations = X.corrwith(y).abs().sort_values(ascending=False)
|
| 185 |
+
|
| 186 |
+
# Select top-n_features
|
| 187 |
+
selected_features = correlations.head(n_features).index.tolist()
|
| 188 |
+
feature_importance = correlations.to_dict()
|
| 189 |
+
|
| 190 |
+
return selected_features, feature_importance
|
| 191 |
+
|
| 192 |
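Correlation selection above simply ranks features by their absolute Pearson correlation with the target and keeps the top n. A toy sketch of the same step (column names are illustrative):

    # Hedged sketch: correlation-based feature ranking on a toy frame.
    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "signal": rng.normal(size=200),
        "noise_a": rng.normal(size=200),
        "noise_b": rng.normal(size=200),
    })
    y = 3 * X["signal"] + rng.normal(scale=0.1, size=200)

    correlations = X.corrwith(y).abs().sort_values(ascending=False)
    print(correlations.round(3).to_dict())
    print(correlations.head(2).index.tolist())  # top-2 features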
+
def _mutual_info_selection(
|
| 193 |
+
self,
|
| 194 |
+
X: pd.DataFrame,
|
| 195 |
+
y: pd.Series,
|
| 196 |
+
n_features: int,
|
| 197 |
+
**kwargs
|
| 198 |
+
) -> Tuple[List[str], Dict]:
|
| 199 |
+
"""Feature selection based on mutual information"""
|
| 200 |
+
try:
|
| 201 |
+
mi_scores = mutual_info_regression(X, y, random_state=kwargs.get('random_state', 42))
|
| 202 |
+
mi_series = pd.Series(mi_scores, index=X.columns)
|
| 203 |
+
mi_series = mi_series.sort_values(ascending=False)
|
| 204 |
+
|
| 205 |
+
selected_features = mi_series.head(n_features).index.tolist()
|
| 206 |
+
feature_importance = mi_series.to_dict()
|
| 207 |
+
|
| 208 |
+
return selected_features, feature_importance
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
logger.warning(f"Mutual information selection failed: {e}, using correlation")
|
| 212 |
+
return self._correlation_selection(X, y, n_features, **kwargs)
|
| 213 |
+
|
| 214 |
+
def _random_forest_selection(
|
| 215 |
+
self,
|
| 216 |
+
X: pd.DataFrame,
|
| 217 |
+
y: pd.Series,
|
| 218 |
+
n_features: int,
|
| 219 |
+
**kwargs
|
| 220 |
+
) -> Tuple[List[str], Dict]:
|
| 221 |
+
"""Feature selection based on Random Forest"""
|
| 222 |
+
try:
|
| 223 |
+
rf = RandomForestRegressor(
|
| 224 |
+
n_estimators=kwargs.get('n_estimators', 100),
|
| 225 |
+
max_depth=kwargs.get('max_depth', None),
|
| 226 |
+
random_state=kwargs.get('random_state', 42),
|
| 227 |
+
n_jobs=self.config.n_jobs if self.config.use_multiprocessing else None
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
rf.fit(X, y)
|
| 231 |
+
importances = pd.Series(rf.feature_importances_, index=X.columns)
|
| 232 |
+
importances = importances.sort_values(ascending=False)
|
| 233 |
+
|
| 234 |
+
selected_features = importances.head(n_features).index.tolist()
|
| 235 |
+
feature_importance = importances.to_dict()
|
| 236 |
+
|
| 237 |
+
self.selector_objects['random_forest'] = rf
|
| 238 |
+
|
| 239 |
+
return selected_features, feature_importance
|
| 240 |
+
|
| 241 |
+
except Exception as e:
|
| 242 |
+
logger.warning(f"Random Forest selection failed: {e}, using correlation")
|
| 243 |
+
return self._correlation_selection(X, y, n_features, **kwargs)
|
| 244 |
+
|
| 245 |
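The Random Forest route above relies on scikit-learn's impurity-based feature_importances_. A standalone sketch on synthetic data where two of four columns actually drive the target:

    # Hedged sketch: Random Forest feature importances.
    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestRegressor

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(300, 4)), columns=["f1", "f2", "f3", "f4"])
    y = 2 * X["f1"] - X["f3"] + rng.normal(scale=0.1, size=300)

    rf = RandomForestRegressor(n_estimators=100, random_state=42).fit(X, y)
    importances = pd.Series(rf.feature_importances_, index=X.columns).sort_values(ascending=False)
    print(importances.head(2).index.tolist())  # expected: f1 and f3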
+
def _pca_selection(
|
| 246 |
+
self,
|
| 247 |
+
X: pd.DataFrame,
|
| 248 |
+
y: pd.Series,
|
| 249 |
+
n_features: int,
|
| 250 |
+
**kwargs
|
| 251 |
+
) -> Tuple[List[str], Dict]:
|
| 252 |
+
"""Feature selection based on PCA"""
|
| 253 |
+
try:
|
| 254 |
+
# First standardise data
|
| 255 |
+
from sklearn.preprocessing import StandardScaler
|
| 256 |
+
|
| 257 |
+
scaler = StandardScaler()
|
| 258 |
+
X_scaled = scaler.fit_transform(X)
|
| 259 |
+
|
| 260 |
+
# Apply PCA
|
| 261 |
+
pca = PCA(n_components=min(n_features, len(X.columns)))
|
| 262 |
+
X_pca = pca.fit_transform(X_scaled)
|
| 263 |
+
|
| 264 |
+
# Get feature importance via absolute component values
|
| 265 |
+
importance = np.abs(pca.components_).sum(axis=0)
|
| 266 |
+
importance_series = pd.Series(importance, index=X.columns)
|
| 267 |
+
importance_series = importance_series.sort_values(ascending=False)
|
| 268 |
+
|
| 269 |
+
selected_features = importance_series.head(n_features).index.tolist()
|
| 270 |
+
feature_importance = importance_series.to_dict()
|
| 271 |
+
|
| 272 |
+
self.selector_objects['pca'] = pca
|
| 273 |
+
self.selector_objects['scaler'] = scaler
|
| 274 |
+
|
| 275 |
+
return selected_features, feature_importance
|
| 276 |
+
|
| 277 |
+
except Exception as e:
|
| 278 |
+
logger.warning(f"PCA selection failed: {e}, using correlation")
|
| 279 |
+
return self._correlation_selection(X, y, n_features, **kwargs)
|
| 280 |
+
|
| 281 |
+
def _rfe_selection(
|
| 282 |
+
self,
|
| 283 |
+
X: pd.DataFrame,
|
| 284 |
+
y: pd.Series,
|
| 285 |
+
n_features: int,
|
| 286 |
+
**kwargs
|
| 287 |
+
) -> Tuple[List[str], Dict]:
|
| 288 |
+
"""Recursive Feature Elimination"""
|
| 289 |
+
try:
|
| 290 |
+
from sklearn.feature_selection import RFE
|
| 291 |
+
from sklearn.linear_model import LinearRegression
|
| 292 |
+
|
| 293 |
+
estimator = LinearRegression()
|
| 294 |
+
rfe = RFE(
|
| 295 |
+
estimator=estimator,
|
| 296 |
+
n_features_to_select=n_features,
|
| 297 |
+
step=kwargs.get('step', 1)
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
rfe.fit(X, y)
|
| 301 |
+
selected_mask = rfe.support_
|
| 302 |
+
selected_features = X.columns[selected_mask].tolist()
|
| 303 |
+
|
| 304 |
+
# Feature importance via ranking
|
| 305 |
+
ranking = pd.Series(rfe.ranking_, index=X.columns)
|
| 306 |
+
feature_importance = (1 / ranking).to_dict() # Convert ranking to importance
|
| 307 |
+
|
| 308 |
+
self.selector_objects['rfe'] = rfe
|
| 309 |
+
|
| 310 |
+
return selected_features, feature_importance
|
| 311 |
+
|
| 312 |
+
except Exception as e:
|
| 313 |
+
logger.warning(f"RFE selection failed: {e}, using correlation")
|
| 314 |
+
return self._correlation_selection(X, y, n_features, **kwargs)
|
| 315 |
+
|
| 316 |
+
def _lasso_selection(
|
| 317 |
+
self,
|
| 318 |
+
X: pd.DataFrame,
|
| 319 |
+
y: pd.Series,
|
| 320 |
+
n_features: int,
|
| 321 |
+
**kwargs
|
| 322 |
+
) -> Tuple[List[str], Dict]:
|
| 323 |
+
"""Feature selection using Lasso"""
|
| 324 |
+
try:
|
| 325 |
+
from sklearn.linear_model import LassoCV
|
| 326 |
+
|
| 327 |
+
lasso = LassoCV(
|
| 328 |
+
cv=kwargs.get('cv', 5),
|
| 329 |
+
random_state=kwargs.get('random_state', 42),
|
| 330 |
+
max_iter=kwargs.get('max_iter', 1000)
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
lasso.fit(X, y)
|
| 334 |
+
|
| 335 |
+
# Features with non-zero coefficients
|
| 336 |
+
coefficients = pd.Series(lasso.coef_, index=X.columns)
|
| 337 |
+
non_zero_features = coefficients[coefficients != 0].abs().sort_values(ascending=False)
|
| 338 |
+
|
| 339 |
+
# Select top-n_features
|
| 340 |
+
selected_features = non_zero_features.head(n_features).index.tolist()
|
| 341 |
+
feature_importance = non_zero_features.to_dict()
|
| 342 |
+
|
| 343 |
+
self.selector_objects['lasso'] = lasso
|
| 344 |
+
|
| 345 |
+
return selected_features, feature_importance
|
| 346 |
+
|
| 347 |
+
except Exception as e:
|
| 348 |
+
logger.warning(f"Lasso selection failed: {e}, using correlation")
|
| 349 |
+
return self._correlation_selection(X, y, n_features, **kwargs)
|
| 350 |
+
|
| 351 |
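Lasso keeps only the features whose coefficients survive L1 shrinkage. A minimal sketch with LassoCV; standardising the inputs first is an added assumption in this sketch, not something the method above does:

    # Hedged sketch: feature selection via LassoCV non-zero coefficients.
    import numpy as np
    import pandas as pd
    from sklearn.linear_model import LassoCV
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(300, 5)), columns=[f"x{i}" for i in range(5)])
    y = 4 * X["x0"] - 2 * X["x2"] + rng.normal(scale=0.5, size=300)

    X_scaled = pd.DataFrame(StandardScaler().fit_transform(X), columns=X.columns)
    lasso = LassoCV(cv=5, random_state=42, max_iter=1000).fit(X_scaled, y)

    coefficients = pd.Series(lasso.coef_, index=X.columns)
    selected = coefficients[coefficients != 0].abs().sort_values(ascending=False).index.tolist()
    print(selected)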
+
def _hybrid_selection(
|
| 352 |
+
self,
|
| 353 |
+
X: pd.DataFrame,
|
| 354 |
+
y: pd.Series,
|
| 355 |
+
n_features: int,
|
| 356 |
+
**kwargs
|
| 357 |
+
) -> Tuple[List[str], Dict]:
|
| 358 |
+
"""Hybrid feature selection method"""
|
| 359 |
+
# Combine multiple methods
|
| 360 |
+
methods = kwargs.get('methods', ['correlation', 'mutual_info', 'rf'])
|
| 361 |
+
weights = kwargs.get('weights', [0.3, 0.3, 0.4])
|
| 362 |
+
|
| 363 |
+
all_importances = {}
|
| 364 |
+
|
| 365 |
+
for method, weight in zip(methods, weights):
|
| 366 |
+
try:
|
| 367 |
+
if method == 'correlation':
|
| 368 |
+
_, importance = self._correlation_selection(X, y, n_features, **kwargs)
|
| 369 |
+
elif method == 'mutual_info':
|
| 370 |
+
_, importance = self._mutual_info_selection(X, y, n_features, **kwargs)
|
| 371 |
+
elif method == 'rf':
|
| 372 |
+
_, importance = self._random_forest_selection(X, y, n_features, **kwargs)
|
| 373 |
+
else:
|
| 374 |
+
continue
|
| 375 |
+
|
| 376 |
+
# Normalise importances and weight them
|
| 377 |
+
importance_series = pd.Series(importance)
|
| 378 |
+
if importance_series.max() > importance_series.min():
|
| 379 |
+
importance_normalized = (importance_series - importance_series.min()) / \
|
| 380 |
+
(importance_series.max() - importance_series.min())
|
| 381 |
+
else:
|
| 382 |
+
importance_normalized = pd.Series(1, index=importance_series.index)
|
| 383 |
+
|
| 384 |
+
# Add weighted importances
|
| 385 |
+
for feature in importance_normalized.index:
|
| 386 |
+
if feature not in all_importances:
|
| 387 |
+
all_importances[feature] = 0
|
| 388 |
+
all_importances[feature] += importance_normalized[feature] * weight
|
| 389 |
+
|
| 390 |
+
except Exception as e:
|
| 391 |
+
logger.debug(f"Method {method} failed in hybrid selection: {e}")
|
| 392 |
+
|
| 393 |
+
# Sort by total importance
|
| 394 |
+
combined_importance = pd.Series(all_importances).sort_values(ascending=False)
|
| 395 |
+
selected_features = combined_importance.head(n_features).index.tolist()
|
| 396 |
+
|
| 397 |
+
return selected_features, combined_importance.to_dict()
|
| 398 |
+
|
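The hybrid method min-max normalises each method's importances and sums them with weights. The same combination step in isolation, with made-up importance values:

    # Hedged sketch: weighted combination of two normalised importance rankings.
    import pandas as pd

    corr_importance = pd.Series({"a": 0.9, "b": 0.4, "c": 0.1})
    rf_importance = pd.Series({"a": 0.5, "b": 0.45, "c": 0.05})

    def normalise(s: pd.Series) -> pd.Series:
        if s.max() > s.min():
            return (s - s.min()) / (s.max() - s.min())
        return pd.Series(1.0, index=s.index)

    combined = 0.5 * normalise(corr_importance) + 0.5 * normalise(rf_importance)
    print(combined.sort_values(ascending=False).index.tolist())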
| 399 |
+
def _plot_feature_selection(
|
| 400 |
+
self,
|
| 401 |
+
X: pd.DataFrame,
|
| 402 |
+
y: pd.Series,
|
| 403 |
+
selected_features: List[str],
|
| 404 |
+
feature_importance: Dict,
|
| 405 |
+
method: str
|
| 406 |
+
) -> None:
|
| 407 |
+
"""Visualise feature selection results"""
|
| 408 |
+
# Prepare data for visualisation
|
| 409 |
+
importance_series = pd.Series(feature_importance).sort_values(ascending=False)
|
| 410 |
+
|
| 411 |
+
# Limit number of features for display
|
| 412 |
+
display_features = importance_series.head(20)
|
| 413 |
+
|
| 414 |
+
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
| 415 |
+
|
| 416 |
+
# 1. Feature importance
|
| 417 |
+
y_pos = np.arange(len(display_features))
|
| 418 |
+
axes[0, 0].barh(y_pos, display_features.values)
|
| 419 |
+
axes[0, 0].set_yticks(y_pos)
|
| 420 |
+
axes[0, 0].set_yticklabels(display_features.index, fontsize=9)
|
| 421 |
+
axes[0, 0].invert_yaxis()
|
| 422 |
+
axes[0, 0].set_xlabel('Importance')
|
| 423 |
+
axes[0, 0].set_title(f'Top-{len(display_features)} features by importance ({method})')
|
| 424 |
+
axes[0, 0].grid(True, alpha=0.3, axis='x')
|
| 425 |
+
|
| 426 |
+
# 2. Cumulative importance
|
| 427 |
+
cumulative_importance = importance_series.cumsum() / importance_series.sum()
|
| 428 |
+
axes[0, 1].plot(range(1, len(cumulative_importance) + 1), cumulative_importance.values)
|
| 429 |
+
axes[0, 1].axhline(y=0.8, color='r', linestyle='--', alpha=0.7, label='80% importance')
|
| 430 |
+
axes[0, 1].axhline(y=0.9, color='orange', linestyle='--', alpha=0.7, label='90% importance')
|
| 431 |
+
axes[0, 1].set_xlabel('Number of features')
|
| 432 |
+
axes[0, 1].set_ylabel('Cumulative importance')
|
| 433 |
+
axes[0, 1].set_title('Cumulative feature importance')
|
| 434 |
+
axes[0, 1].legend()
|
| 435 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 436 |
+
|
| 437 |
+
# 3. Correlation matrix of selected features
|
| 438 |
+
if len(selected_features) > 1:
|
| 439 |
+
selected_X = X[selected_features]
|
| 440 |
+
corr_matrix = selected_X.corr()
|
| 441 |
+
|
| 442 |
+
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
|
| 443 |
+
sns.heatmap(
|
| 444 |
+
corr_matrix,
|
| 445 |
+
annot=True,
|
| 446 |
+
fmt='.2f',
|
| 447 |
+
cmap='coolwarm',
|
| 448 |
+
center=0,
|
| 449 |
+
square=True,
|
| 450 |
+
mask=mask,
|
| 451 |
+
cbar_kws={'shrink': 0.8},
|
| 452 |
+
ax=axes[1, 0]
|
| 453 |
+
)
|
| 454 |
+
axes[1, 0].set_title(f'Correlation of selected features ({len(selected_features)})')
|
| 455 |
+
|
| 456 |
+
# 4. Importance distribution
|
| 457 |
+
axes[1, 1].hist(importance_series.values, bins=30, edgecolor='black', alpha=0.7)
|
| 458 |
+
axes[1, 1].set_xlabel('Feature importance')
|
| 459 |
+
axes[1, 1].set_ylabel('Frequency')
|
| 460 |
+
axes[1, 1].set_title('Feature importance distribution')
|
| 461 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 462 |
+
|
| 463 |
+
plt.suptitle(f'Feature selection results using {method} method', fontsize=14)
|
| 464 |
+
plt.tight_layout()
|
| 465 |
+
plt.savefig(
|
| 466 |
+
f'{self.config.results_dir}/plots/feature_selection_{method}.png',
|
| 467 |
+
dpi=300,
|
| 468 |
+
bbox_inches='tight'
|
| 469 |
+
)
|
| 470 |
+
plt.show()
|
| 471 |
+
|
| 472 |
+
def get_report(self) -> Dict:
|
| 473 |
+
"""Get feature selection report"""
|
| 474 |
+
return {
|
| 475 |
+
'selected_features': self.selected_features,
|
| 476 |
+
'feature_importances': self.feature_importances,
|
| 477 |
+
'selection_methods': self.selection_methods
|
| 478 |
+
}
|
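A hedged usage sketch for this class. It assumes Config() can be built with defaults and that the loaded frame contains a numeric target column; "value" and the CSV path are placeholders, not part of the module above:

    # Hedged usage sketch for FeatureSelector (placeholder data and column name).
    import pandas as pd
    from config.config import Config
    from feature_selection.feature_selector import FeatureSelector

    config = Config()                       # assumed default constructor
    data = pd.read_csv("temp_data.csv")     # placeholder dataset
    selector = FeatureSelector(config)
    data_selected = selector.select(data, target_col="value", method="rf", n_features=10)
    print(selector.get_report()["selected_features"])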
features/__init__.py
ADDED
|
File without changes
|
features/feature_engineer.py
ADDED
|
@@ -0,0 +1,638 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 5: FEATURE ENGINEER
|
| 3 |
+
# ============================================
|
| 4 |
+
from typing import Dict, List, Optional
|
| 5 |
+
import logging; logger = logging.getLogger(__name__)  # fix: logger was accidentally imported from venv
|
| 6 |
+
|
| 7 |
+
from config.config import Config
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FeatureEngineer:
|
| 14 |
+
"""Class for creating new features for time series"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, config: Config):
|
| 17 |
+
"""
|
| 18 |
+
Initialise feature engineer
|
| 19 |
+
|
| 20 |
+
Parameters:
|
| 21 |
+
-----------
|
| 22 |
+
config : Config
|
| 23 |
+
Experiment configuration
|
| 24 |
+
"""
|
| 25 |
+
self.config = config
|
| 26 |
+
self.created_features = []
|
| 27 |
+
self.feature_info = {}
|
| 28 |
+
self.feature_importances = {}
|
| 29 |
+
self.transforms_applied = {}
|
| 30 |
+
|
| 31 |
+
def create_all_features(
|
| 32 |
+
self,
|
| 33 |
+
data: pd.DataFrame,
|
| 34 |
+
target_col: Optional[str] = None
|
| 35 |
+
) -> pd.DataFrame:
|
| 36 |
+
"""
|
| 37 |
+
Create all types of features
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
-----------
|
| 41 |
+
data : pd.DataFrame
|
| 42 |
+
Input data
|
| 43 |
+
target_col : str, optional
|
| 44 |
+
Target variable. If None, uses configuration value.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
--------
|
| 48 |
+
pd.DataFrame
|
| 49 |
+
Data with all features
|
| 50 |
+
"""
|
| 51 |
+
logger.info("\n" + "="*80)
|
| 52 |
+
logger.info("CREATING FEATURES FOR TIME SERIES")
|
| 53 |
+
logger.info("="*80)
|
| 54 |
+
|
| 55 |
+
target_col = target_col or self.config.target_column
|
| 56 |
+
initial_features = len(data.columns)
|
| 57 |
+
initial_rows = len(data)
|
| 58 |
+
|
| 59 |
+
# Check and save index
|
| 60 |
+
original_index = data.index
|
| 61 |
+
index_is_datetime = isinstance(original_index, pd.DatetimeIndex)
|
| 62 |
+
|
| 63 |
+
logger.info(f"Initial number of features: {initial_features}")
|
| 64 |
+
logger.info(f"Initial number of rows: {initial_rows}")
|
| 65 |
+
logger.info(f"Index is DatetimeIndex: {index_is_datetime}")
|
| 66 |
+
|
| 67 |
+
# If index not DatetimeIndex but 'date' column exists
|
| 68 |
+
if not index_is_datetime and 'date' in data.columns:
|
| 69 |
+
logger.info("Attempting to set DatetimeIndex from 'date' column")
|
| 70 |
+
try:
|
| 71 |
+
data = data.set_index('date')
|
| 72 |
+
if isinstance(data.index, pd.DatetimeIndex):
|
| 73 |
+
index_is_datetime = True
|
| 74 |
+
original_index = data.index
|
| 75 |
+
logger.info("✓ DatetimeIndex set from 'date' column")
|
| 76 |
+
else:
|
| 77 |
+
logger.warning("Failed to set DatetimeIndex")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
logger.warning(f"Error setting DatetimeIndex: {e}")
|
| 80 |
+
|
| 81 |
+
# Save data copy for index restoration later
|
| 82 |
+
data_processed = data.copy()
|
| 83 |
+
|
| 84 |
+
# 1. Create basic temporal features (if date exists)
|
| 85 |
+
if index_is_datetime:
|
| 86 |
+
logger.info("\n1. BASIC TEMPORAL FEATURES")
|
| 87 |
+
data_processed = self.create_temporal_features(data_processed)
|
| 88 |
+
else:
|
| 89 |
+
logger.info("\n1. BASIC TEMPORAL FEATURES: skipped (no DatetimeIndex)")
|
| 90 |
+
|
| 91 |
+
# 2. Create statistical features
|
| 92 |
+
logger.info("\n2. STATISTICAL FEATURES")
|
| 93 |
+
data_processed = self.create_statistical_features(data_processed, target_col)
|
| 94 |
+
|
| 95 |
+
# 3. Create rolling features
|
| 96 |
+
logger.info("\n3. ROLLING FEATURES")
|
| 97 |
+
data_processed = self.create_rolling_features(data_processed, target_col)
|
| 98 |
+
|
| 99 |
+
# 4. Create lag features (limited quantity)
|
| 100 |
+
logger.info("\n4. LAG FEATURES")
|
| 101 |
+
data_processed = self.create_lag_features(data_processed, target_col)
|
| 102 |
+
|
| 103 |
+
# 5. Create interaction features
|
| 104 |
+
logger.info("\n5. INTERACTION FEATURES")
|
| 105 |
+
data_processed = self.create_interaction_features(data_processed, target_col)
|
| 106 |
+
|
| 107 |
+
# 6. Create spectral features (only if sufficient data)
|
| 108 |
+
logger.info("\n6. SPECTRAL FEATURES")
|
| 109 |
+
if len(data_processed) > 100:
|
| 110 |
+
data_processed = self.create_spectral_features(data_processed, target_col)
|
| 111 |
+
else:
|
| 112 |
+
logger.info(" Skipped: insufficient data")
|
| 113 |
+
|
| 114 |
+
# 7. Create decomposition features (only if sufficient data and date exists)
|
| 115 |
+
logger.info("\n7. DECOMPOSITION FEATURES")
|
| 116 |
+
if len(data_processed) > 365 and index_is_datetime:
|
| 117 |
+
data_processed = self.create_decomposition_features(data_processed, target_col)
|
| 118 |
+
else:
|
| 119 |
+
logger.info(" Skipped: insufficient data or no DatetimeIndex")
|
| 120 |
+
|
| 121 |
+
# Remove rows with NaN that appeared due to lags and differences
|
| 122 |
+
rows_before_nan = len(data_processed)
|
| 123 |
+
data_processed = data_processed.dropna()
|
| 124 |
+
rows_after_nan = len(data_processed)
|
| 125 |
+
removed_rows = rows_before_nan - rows_after_nan
|
| 126 |
+
|
| 127 |
+
# Remove constant features
|
| 128 |
+
constant_features = []
|
| 129 |
+
for col in data_processed.columns:
|
| 130 |
+
if data_processed[col].nunique() <= 1:
|
| 131 |
+
constant_features.append(col)
|
| 132 |
+
|
| 133 |
+
if constant_features:
|
| 134 |
+
logger.info(f"\nRemoving constant features: {len(constant_features)} found")
|
| 135 |
+
for feat in constant_features[:10]:
|
| 136 |
+
logger.info(f" - {feat}")
|
| 137 |
+
if len(constant_features) > 10:
|
| 138 |
+
logger.info(f" ... and {len(constant_features) - 10} more features")
|
| 139 |
+
|
| 140 |
+
data_processed = data_processed.drop(columns=constant_features)
|
| 141 |
+
# Update created features list
|
| 142 |
+
self.created_features = [f for f in self.created_features if f not in constant_features]
|
| 143 |
+
|
| 144 |
+
# Save information
|
| 145 |
+
self.feature_info = {
|
| 146 |
+
'initial_features': initial_features,
|
| 147 |
+
'final_features': len(data_processed.columns),
|
| 148 |
+
'features_created': len(self.created_features),
|
| 149 |
+
'initial_rows': initial_rows,
|
| 150 |
+
'final_rows': len(data_processed),
|
| 151 |
+
'removed_rows': removed_rows,
|
| 152 |
+
'constant_features_removed': len(constant_features),
|
| 153 |
+
'created_features_list': self.created_features,
|
| 154 |
+
'feature_categories': self.get_feature_categories()
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
logger.info(f"\nFeature creation summary:")
|
| 158 |
+
logger.info(f" Initial number of features: {initial_features}")
|
| 159 |
+
logger.info(f" Final number of features: {len(data_processed.columns)}")
|
| 160 |
+
logger.info(f" New features created: {len(self.created_features)}")
|
| 161 |
+
logger.info(f" Initial number of rows: {initial_rows}")
|
| 162 |
+
logger.info(f" Final number of rows: {len(data_processed)}")
|
| 163 |
+
logger.info(f" Rows removed due to NaN: {removed_rows}")
|
| 164 |
+
logger.info(f" Constant features removed: {len(constant_features)}")
|
| 165 |
+
|
| 166 |
+
return data_processed
|
| 167 |
+
|
| 168 |
+
def create_temporal_features(self, data: pd.DataFrame) -> pd.DataFrame:
|
| 169 |
+
"""
|
| 170 |
+
Create temporal features
|
| 171 |
+
|
| 172 |
+
Parameters:
|
| 173 |
+
-----------
|
| 174 |
+
data : pd.DataFrame
|
| 175 |
+
Input data
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
--------
|
| 179 |
+
pd.DataFrame
|
| 180 |
+
Data with temporal features
|
| 181 |
+
"""
|
| 182 |
+
data_processed = data.copy()
|
| 183 |
+
|
| 184 |
+
if not isinstance(data_processed.index, pd.DatetimeIndex):
|
| 185 |
+
logger.warning("Temporal features not created: index not DatetimeIndex")
|
| 186 |
+
return data_processed
|
| 187 |
+
|
| 188 |
+
try:
|
| 189 |
+
# Basic temporal features
|
| 190 |
+
data_processed['year'] = data_processed.index.year
|
| 191 |
+
data_processed['month'] = data_processed.index.month
|
| 192 |
+
data_processed['day'] = data_processed.index.day
|
| 193 |
+
data_processed['dayofyear'] = data_processed.index.dayofyear
|
| 194 |
+
data_processed['dayofweek'] = data_processed.index.dayofweek
|
| 195 |
+
data_processed['weekofyear'] = data_processed.index.isocalendar().week.astype(int)
|
| 196 |
+
data_processed['quarter'] = data_processed.index.quarter
|
| 197 |
+
data_processed['is_weekend'] = data_processed['dayofweek'].isin([5, 6]).astype(int)
|
| 198 |
+
|
| 199 |
+
# Cyclic features for seasonality
|
| 200 |
+
data_processed['month_sin'] = np.sin(2 * np.pi * data_processed['month'] / 12)
|
| 201 |
+
data_processed['month_cos'] = np.cos(2 * np.pi * data_processed['month'] / 12)
|
| 202 |
+
data_processed['dayofyear_sin'] = np.sin(2 * np.pi * data_processed['dayofyear'] / 365.25)
|
| 203 |
+
data_processed['dayofyear_cos'] = np.cos(2 * np.pi * data_processed['dayofyear'] / 365.25)
|
| 204 |
+
data_processed['dayofweek_sin'] = np.sin(2 * np.pi * data_processed['dayofweek'] / 7)
|
| 205 |
+
data_processed['dayofweek_cos'] = np.cos(2 * np.pi * data_processed['dayofweek'] / 7)
|
| 206 |
+
|
| 207 |
+
# Time in days from start (relative features)
|
| 208 |
+
min_date = data_processed.index.min()
|
| 209 |
+
data_processed['days_from_start'] = (data_processed.index - min_date).days
|
| 210 |
+
|
| 211 |
+
# Register created features
|
| 212 |
+
temporal_features = ['year', 'month', 'day', 'dayofyear', 'dayofweek',
|
| 213 |
+
'weekofyear', 'quarter', 'is_weekend', 'month_sin',
|
| 214 |
+
'month_cos', 'dayofyear_sin', 'dayofyear_cos',
|
| 215 |
+
'dayofweek_sin', 'dayofweek_cos', 'days_from_start']
|
| 216 |
+
|
| 217 |
+
self.created_features.extend([f for f in temporal_features if f not in self.created_features])
|
| 218 |
+
|
| 219 |
+
logger.info(f"✓ Created {len(temporal_features)} temporal features")
|
| 220 |
+
|
| 221 |
+
except Exception as e:
|
| 222 |
+
logger.warning(f"Error creating temporal features: {e}")
|
| 223 |
+
|
| 224 |
+
return data_processed
|
| 225 |
+
|
| 226 |
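The sin/cos pairs above encode cyclical calendar fields so that neighbouring values wrap around (December sits next to January). A standalone sketch of the same encoding:

    # Hedged sketch: cyclic encoding of calendar features on a DatetimeIndex.
    import numpy as np
    import pandas as pd

    idx = pd.date_range("2024-01-01", periods=60, freq="D")
    df = pd.DataFrame(index=idx)
    df["month"] = df.index.month
    df["month_sin"] = np.sin(2 * np.pi * df["month"] / 12)
    df["month_cos"] = np.cos(2 * np.pi * df["month"] / 12)
    df["dayofweek"] = df.index.dayofweek
    df["is_weekend"] = df["dayofweek"].isin([5, 6]).astype(int)
    print(df.head(3))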
+
def create_statistical_features(
|
| 227 |
+
self,
|
| 228 |
+
data: pd.DataFrame,
|
| 229 |
+
target_col: str
|
| 230 |
+
) -> pd.DataFrame:
|
| 231 |
+
"""
|
| 232 |
+
Create statistical features
|
| 233 |
+
|
| 234 |
+
Parameters:
|
| 235 |
+
-----------
|
| 236 |
+
data : pd.DataFrame
|
| 237 |
+
Input data
|
| 238 |
+
target_col : str
|
| 239 |
+
Target variable
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
--------
|
| 243 |
+
pd.DataFrame
|
| 244 |
+
Data with statistical features
|
| 245 |
+
"""
|
| 246 |
+
data_processed = data.copy()
|
| 247 |
+
|
| 248 |
+
if target_col not in data_processed.columns:
|
| 249 |
+
logger.warning(f"Target variable '{target_col}' not found")
|
| 250 |
+
return data_processed
|
| 251 |
+
|
| 252 |
+
# Only if we have year data
|
| 253 |
+
if 'year' in data_processed.columns:
|
| 254 |
+
# Yearly statistics
|
| 255 |
+
try:
|
| 256 |
+
yearly_stats = data_processed.groupby('year')[target_col].agg([
|
| 257 |
+
'mean', 'std', 'min', 'max', 'median'
|
| 258 |
+
])
|
| 259 |
+
yearly_stats.columns = [f'{target_col}_yearly_{col}' for col in yearly_stats.columns]
|
| 260 |
+
data_processed = data_processed.merge(yearly_stats, on='year', how='left')
|
| 261 |
+
|
| 262 |
+
# Add created features to list
|
| 263 |
+
for col in yearly_stats.columns:
|
| 264 |
+
self.created_features.append(col)
|
| 265 |
+
except Exception as e:
|
| 266 |
+
logger.debug(f"Yearly statistics not created: {e}")
|
| 267 |
+
|
| 268 |
+
# Normalised features (only if there is variation)
|
| 269 |
+
std_val = data_processed[target_col].std()
|
| 270 |
+
if std_val > 0:
|
| 271 |
+
data_processed[f'{target_col}_zscore'] = (data_processed[target_col] - data_processed[target_col].mean()) / std_val
|
| 272 |
+
self.created_features.append(f'{target_col}_zscore')
|
| 273 |
+
|
| 274 |
+
# Features based on percentiles (binary features)
|
| 275 |
+
try:
|
| 276 |
+
for p in [0.25, 0.5, 0.75]:
|
| 277 |
+
quantile_val = data_processed[target_col].quantile(p)
|
| 278 |
+
data_processed[f'{target_col}_above_p{int(p*100)}'] = (data_processed[target_col] > quantile_val).astype(int)
|
| 279 |
+
self.created_features.append(f'{target_col}_above_p{int(p*100)}')
|
| 280 |
+
except Exception as e:
|
| 281 |
+
logger.debug(f"Quantile features not created: {e}")
|
| 282 |
+
|
| 283 |
+
logger.info(f"✓ Statistical features created: {len([c for c in data_processed.columns if c not in data.columns])}")
|
| 284 |
+
return data_processed
|
| 285 |
+
|
| 286 |
+
def create_rolling_features(
|
| 287 |
+
self,
|
| 288 |
+
data: pd.DataFrame,
|
| 289 |
+
target_col: str
|
| 290 |
+
) -> pd.DataFrame:
|
| 291 |
+
"""
|
| 292 |
+
Create rolling statistics
|
| 293 |
+
|
| 294 |
+
Parameters:
|
| 295 |
+
-----------
|
| 296 |
+
data : pd.DataFrame
|
| 297 |
+
Input data
|
| 298 |
+
target_col : str
|
| 299 |
+
Target variable
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
--------
|
| 303 |
+
pd.DataFrame
|
| 304 |
+
Data with rolling features
|
| 305 |
+
"""
|
| 306 |
+
data_processed = data.copy()
|
| 307 |
+
|
| 308 |
+
if target_col not in data_processed.columns:
|
| 309 |
+
logger.warning(f"Target variable '{target_col}' not found")
|
| 310 |
+
return data_processed
|
| 311 |
+
|
| 312 |
+
# Use only main windows from configuration
|
| 313 |
+
windows = [w for w in self.config.rolling_windows if w < len(data_processed) // 2]
|
| 314 |
+
|
| 315 |
+
for window in windows:
|
| 316 |
+
try:
|
| 317 |
+
# Basic statistics
|
| 318 |
+
data_processed[f'{target_col}_rolling_mean_{window}'] = data_processed[target_col].rolling(
|
| 319 |
+
window=window, min_periods=max(1, window//4), center=True
|
| 320 |
+
).mean()
|
| 321 |
+
|
| 322 |
+
data_processed[f'{target_col}_rolling_std_{window}'] = data_processed[target_col].rolling(
|
| 323 |
+
window=window, min_periods=max(1, window//4), center=True
|
| 324 |
+
).std()
|
| 325 |
+
|
| 326 |
+
data_processed[f'{target_col}_rolling_min_{window}'] = data_processed[target_col].rolling(
|
| 327 |
+
window=window, min_periods=max(1, window//4), center=True
|
| 328 |
+
).min()
|
| 329 |
+
|
| 330 |
+
data_processed[f'{target_col}_rolling_max_{window}'] = data_processed[target_col].rolling(
|
| 331 |
+
window=window, min_periods=max(1, window//4), center=True
|
| 332 |
+
).max()
|
| 333 |
+
|
| 334 |
+
self.created_features.extend([
|
| 335 |
+
f'{target_col}_rolling_mean_{window}',
|
| 336 |
+
f'{target_col}_rolling_std_{window}',
|
| 337 |
+
f'{target_col}_rolling_min_{window}',
|
| 338 |
+
f'{target_col}_rolling_max_{window}'
|
| 339 |
+
])
|
| 340 |
+
except Exception as e:
|
| 341 |
+
logger.debug(f"Rolling features for window {window} not created: {e}")
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
logger.info(f"✓ Rolling features created: {len([c for c in data_processed.columns if 'rolling' in c and c not in data.columns])}")
|
| 345 |
+
return data_processed
|
| 346 |
+
|
| 347 |
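The rolling block above builds centred windows with a relaxed min_periods so the series edges stay defined. The same construction for a single window on a synthetic series:

    # Hedged sketch: centred rolling statistics for one column.
    import numpy as np
    import pandas as pd

    s = pd.Series(np.random.default_rng(0).normal(size=120).cumsum(), name="value")

    window = 7
    rolled = pd.DataFrame({
        "value": s,
        "rolling_mean_7": s.rolling(window, min_periods=max(1, window // 4), center=True).mean(),
        "rolling_std_7": s.rolling(window, min_periods=max(1, window // 4), center=True).std(),
    })
    print(rolled.head())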
+
def create_lag_features(
|
| 348 |
+
self,
|
| 349 |
+
data: pd.DataFrame,
|
| 350 |
+
target_col: str
|
| 351 |
+
) -> pd.DataFrame:
|
| 352 |
+
"""
|
| 353 |
+
Create lag features
|
| 354 |
+
|
| 355 |
+
Parameters:
|
| 356 |
+
-----------
|
| 357 |
+
data : pd.DataFrame
|
| 358 |
+
Input data
|
| 359 |
+
target_col : str
|
| 360 |
+
Target variable
|
| 361 |
+
|
| 362 |
+
Returns:
|
| 363 |
+
--------
|
| 364 |
+
pd.DataFrame
|
| 365 |
+
Data with lag features
|
| 366 |
+
"""
|
| 367 |
+
data_processed = data.copy()
|
| 368 |
+
|
| 369 |
+
if target_col not in data_processed.columns:
|
| 370 |
+
logger.warning(f"Target variable '{target_col}' not found")
|
| 371 |
+
return data_processed
|
| 372 |
+
|
| 373 |
+
# Limited number of lags
|
| 374 |
+
max_lags = min(self.config.max_lags, 7)  # cap the lag horizon at 7, so the 14- and 30-step lags below are only kept if this cap is raised
|
| 375 |
+
|
| 376 |
+
for lag in [1, 2, 3, 7, 14, 30]:
|
| 377 |
+
if lag <= max_lags:
|
| 378 |
+
data_processed[f'{target_col}_lag_{lag}'] = data_processed[target_col].shift(lag)
|
| 379 |
+
self.created_features.append(f'{target_col}_lag_{lag}')
|
| 380 |
+
|
| 381 |
+
# Seasonal lags (only if sufficient data)
|
| 382 |
+
if len(data_processed) > 365:
|
| 383 |
+
try:
|
| 384 |
+
data_processed[f'{target_col}_seasonal_lag_365'] = data_processed[target_col].shift(365)
|
| 385 |
+
self.created_features.append(f'{target_col}_seasonal_lag_365')
|
| 386 |
+
except Exception as e:
|
| 387 |
+
logger.debug(f"Seasonal lag not created: {e}")
|
| 388 |
+
|
| 389 |
+
# Differences (stationarity)
|
| 390 |
+
data_processed[f'{target_col}_diff_1'] = data_processed[target_col].diff(1)
|
| 391 |
+
self.created_features.append(f'{target_col}_diff_1')
|
| 392 |
+
|
| 393 |
+
if len(data_processed) > 7:
|
| 394 |
+
data_processed[f'{target_col}_diff_7'] = data_processed[target_col].diff(7)
|
| 395 |
+
self.created_features.append(f'{target_col}_diff_7')
|
| 396 |
+
|
| 397 |
+
logger.info(f"✓ Lag features created: {len([c for c in data_processed.columns if ('lag' in c or 'diff' in c) and c not in data.columns])}")
|
| 398 |
+
return data_processed
|
| 399 |
+
|
| 400 |
+
def create_interaction_features(
|
| 401 |
+
self,
|
| 402 |
+
data: pd.DataFrame,
|
| 403 |
+
target_col: str
|
| 404 |
+
) -> pd.DataFrame:
|
| 405 |
+
"""
|
| 406 |
+
Create interaction features
|
| 407 |
+
|
| 408 |
+
Parameters:
|
| 409 |
+
-----------
|
| 410 |
+
data : pd.DataFrame
|
| 411 |
+
Input data
|
| 412 |
+
target_col : str
|
| 413 |
+
Target variable
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
--------
|
| 417 |
+
pd.DataFrame
|
| 418 |
+
Data with interaction features
|
| 419 |
+
"""
|
| 420 |
+
data_processed = data.copy()
|
| 421 |
+
|
| 422 |
+
if target_col not in data_processed.columns:
|
| 423 |
+
logger.warning(f"Target variable '{target_col}' not found")
|
| 424 |
+
return data_processed
|
| 425 |
+
|
| 426 |
+
# Interactions with temperature (only if data exists)
|
| 427 |
+
temp_cols = ['tavg', 'tmin', 'tmax']
|
| 428 |
+
available_temp_cols = [col for col in temp_cols if col in data_processed.columns]
|
| 429 |
+
|
| 430 |
+
for temp_col in available_temp_cols:
|
| 431 |
+
try:
|
| 432 |
+
# Avoid division by zero
|
| 433 |
+
temp_data = data_processed[temp_col].replace(0, np.nan)
|
| 434 |
+
if temp_data.notna().all() and (temp_data != 0).all():
|
| 435 |
+
data_processed[f'{target_col}_{temp_col}_ratio'] = data_processed[target_col] / temp_data
|
| 436 |
+
self.created_features.append(f'{target_col}_{temp_col}_ratio')
|
| 437 |
+
|
| 438 |
+
# Product
|
| 439 |
+
data_processed[f'{target_col}_{temp_col}_product'] = data_processed[target_col] * temp_data
|
| 440 |
+
self.created_features.append(f'{target_col}_{temp_col}_product')
|
| 441 |
+
except Exception as e:
|
| 442 |
+
logger.debug(f"Interaction feature with {temp_col} not created: {e}")
|
| 443 |
+
|
| 444 |
+
# Interaction with water level
|
| 445 |
+
if 'urovenvoda' in data_processed.columns:
|
| 446 |
+
try:
|
| 447 |
+
uroven_data = data_processed['urovenvoda'].replace(0, np.nan)
|
| 448 |
+
if uroven_data.notna().all() and (uroven_data != 0).all():
|
| 449 |
+
data_processed[f'{target_col}_urovenvoda_ratio'] = data_processed[target_col] / uroven_data
|
| 450 |
+
self.created_features.append(f'{target_col}_urovenvoda_ratio')
|
| 451 |
+
except Exception as e:
|
| 452 |
+
logger.debug(f"Interaction feature with urovenvoda not created: {e}")
|
| 453 |
+
|
| 454 |
+
logger.info(f"✓ Interaction features created: {len([c for c in data_processed.columns if ('ratio' in c or 'product' in c) and c not in data.columns])}")
|
| 455 |
+
return data_processed
|
| 456 |
+
|
| 457 |
+
def create_spectral_features(
|
| 458 |
+
self,
|
| 459 |
+
data: pd.DataFrame,
|
| 460 |
+
target_col: str
|
| 461 |
+
) -> pd.DataFrame:
|
| 462 |
+
"""
|
| 463 |
+
Create spectral features
|
| 464 |
+
|
| 465 |
+
Parameters:
|
| 466 |
+
-----------
|
| 467 |
+
data : pd.DataFrame
|
| 468 |
+
Input data
|
| 469 |
+
target_col : str
|
| 470 |
+
Target variable
|
| 471 |
+
|
| 472 |
+
Returns:
|
| 473 |
+
--------
|
| 474 |
+
pd.DataFrame
|
| 475 |
+
Data with spectral features
|
| 476 |
+
"""
|
| 477 |
+
data_processed = data.copy()
|
| 478 |
+
|
| 479 |
+
if target_col not in data_processed.columns:
|
| 480 |
+
logger.warning(f"Target variable '{target_col}' not found")
|
| 481 |
+
return data_processed
|
| 482 |
+
|
| 483 |
+
if len(data_processed) < 100:
|
| 484 |
+
logger.info("Insufficient data for creating spectral features")
|
| 485 |
+
return data_processed
|
| 486 |
+
|
| 487 |
+
try:
|
| 488 |
+
# Fast Fourier Transform
|
| 489 |
+
series = data_processed[target_col].dropna().values
|
| 490 |
+
|
| 491 |
+
if len(series) > 50:
|
| 492 |
+
# Calculate periodogram
|
| 493 |
+
from scipy.signal import periodogram
|
| 494 |
+
freqs, psd = periodogram(series, fs=1.0)
|
| 495 |
+
|
| 496 |
+
# Find dominant frequencies
|
| 497 |
+
if len(psd) > 3:
|
| 498 |
+
# Top-3 frequencies by power
|
| 499 |
+
top_indices = np.argsort(psd)[-3:][::-1]
|
| 500 |
+
|
| 501 |
+
for i, idx in enumerate(top_indices, 1):
|
| 502 |
+
if idx < len(freqs):
|
| 503 |
+
freq = freqs[idx]
|
| 504 |
+
if freq > 0:
|
| 505 |
+
period = 1 / freq
|
| 506 |
+
data_processed[f'{target_col}_dominant_period_{i}'] = period
|
| 507 |
+
self.created_features.append(f'{target_col}_dominant_period_{i}')
|
| 508 |
+
|
| 509 |
+
except Exception as e:
|
| 510 |
+
logger.debug(f"Spectral features creation failed: {e}")
|
| 511 |
+
|
| 512 |
+
return data_processed
|
| 513 |
+
|
| 514 |
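The spectral step above takes the periodogram of the target series, keeps the three strongest frequencies and converts them to periods. The same idea on a synthetic daily series with a known weekly cycle:

    # Hedged sketch: dominant periods from a periodogram.
    import numpy as np
    from scipy.signal import periodogram

    t = np.arange(730)
    series = np.sin(2 * np.pi * t / 7) + 0.3 * np.random.default_rng(0).normal(size=t.size)

    freqs, psd = periodogram(series, fs=1.0)
    top_indices = np.argsort(psd)[-3:][::-1]                 # three strongest frequencies
    periods = [1 / freqs[i] for i in top_indices if freqs[i] > 0]
    print([round(p, 1) for p in periods])                    # should include ~7.0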
+
def create_decomposition_features(
|
| 515 |
+
self,
|
| 516 |
+
data: pd.DataFrame,
|
| 517 |
+
target_col: str
|
| 518 |
+
) -> pd.DataFrame:
|
| 519 |
+
"""
|
| 520 |
+
Create features based on decomposition
|
| 521 |
+
|
| 522 |
+
Parameters:
|
| 523 |
+
-----------
|
| 524 |
+
data : pd.DataFrame
|
| 525 |
+
Input data
|
| 526 |
+
target_col : str
|
| 527 |
+
Target variable
|
| 528 |
+
|
| 529 |
+
Returns:
|
| 530 |
+
--------
|
| 531 |
+
pd.DataFrame
|
| 532 |
+
Data with decomposition features
|
| 533 |
+
"""
|
| 534 |
+
data_processed = data.copy()
|
| 535 |
+
|
| 536 |
+
if target_col not in data_processed.columns:
|
| 537 |
+
logger.warning(f"Target variable '{target_col}' not found")
|
| 538 |
+
return data_processed
|
| 539 |
+
|
| 540 |
+
if len(data_processed) < 365:
|
| 541 |
+
logger.info("Insufficient data for decomposition")
|
| 542 |
+
return data_processed
|
| 543 |
+
|
| 544 |
+
try:
|
| 545 |
+
# Check for date presence
|
| 546 |
+
if isinstance(data_processed.index, pd.DatetimeIndex):
|
| 547 |
+
# STL decomposition
|
| 548 |
+
if len(data_processed) > 730: # Need at least 2 years for yearly seasonality
|
| 549 |
+
try:
|
| 550 |
+
from statsmodels.tsa.seasonal import STL
|
| 551 |
+
|
| 552 |
+
# STL decomposition
|
| 553 |
+
stl = STL(
|
| 554 |
+
data_processed[target_col].ffill(),
|
| 555 |
+
period=365,
|
| 556 |
+
robust=True
|
| 557 |
+
)
|
| 558 |
+
result = stl.fit()
|
| 559 |
+
|
| 560 |
+
# Add components
|
| 561 |
+
data_processed[f'{target_col}_trend'] = result.trend
|
| 562 |
+
data_processed[f'{target_col}_seasonal'] = result.seasonal
|
| 563 |
+
data_processed[f'{target_col}_residual'] = result.resid
|
| 564 |
+
|
| 565 |
+
self.created_features.extend([
|
| 566 |
+
f'{target_col}_trend',
|
| 567 |
+
f'{target_col}_seasonal',
|
| 568 |
+
f'{target_col}_residual'
|
| 569 |
+
])
|
| 570 |
+
|
| 571 |
+
logger.info("✓ STL decomposition successful")
|
| 572 |
+
|
| 573 |
+
except Exception as e:
|
| 574 |
+
logger.debug(f"STL decomposition failed: {e}")
|
| 575 |
+
# Simple seasonal decomposition
|
| 576 |
+
try:
|
| 577 |
+
from statsmodels.tsa.seasonal import seasonal_decompose
|
| 578 |
+
|
| 579 |
+
decomposition = seasonal_decompose(
|
| 580 |
+
data_processed[target_col].ffill(),
|
| 581 |
+
model='additive',
|
| 582 |
+
period=365,
|
| 583 |
+
extrapolate_trend='freq'
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
data_processed[f'{target_col}_trend'] = decomposition.trend
|
| 587 |
+
data_processed[f'{target_col}_seasonal'] = decomposition.seasonal
|
| 588 |
+
|
| 589 |
+
self.created_features.extend([
|
| 590 |
+
f'{target_col}_trend',
|
| 591 |
+
f'{target_col}_seasonal'
|
| 592 |
+
])
|
| 593 |
+
|
| 594 |
+
logger.info("✓ Seasonal decomposition successful")
|
| 595 |
+
except Exception as e2:
|
| 596 |
+
logger.debug(f"Seasonal decomposition failed: {e2}")
|
| 597 |
+
|
| 598 |
+
except Exception as e:
|
| 599 |
+
logger.debug(f"Decomposition features creation failed: {e}")
|
| 600 |
+
|
| 601 |
+
return data_processed
|
| 602 |
+
|
| 603 |
+
def get_feature_categories(self) -> Dict[str, List[str]]:
|
| 604 |
+
"""Get features by categories"""
|
| 605 |
+
categories = {
|
| 606 |
+
'temporal': [],
|
| 607 |
+
'statistical': [],
|
| 608 |
+
'rolling': [],
|
| 609 |
+
'lag': [],
|
| 610 |
+
'interaction': [],
|
| 611 |
+
'spectral': [],
|
| 612 |
+
'decomposition': [],
|
| 613 |
+
'binary': []
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
for feature in self.created_features:
|
| 617 |
+
if any(keyword in feature for keyword in ['year', 'month', 'day', 'week', 'quarter', 'sin', 'cos', 'is_weekend']):
|
| 618 |
+
categories['temporal'].append(feature)
|
| 619 |
+
elif any(keyword in feature for keyword in ['zscore', 'above_p', 'yearly_']):
|
| 620 |
+
if 'above_p' in feature:
|
| 621 |
+
categories['binary'].append(feature)
|
| 622 |
+
else:
|
| 623 |
+
categories['statistical'].append(feature)
|
| 624 |
+
elif 'rolling' in feature:
|
| 625 |
+
categories['rolling'].append(feature)
|
| 626 |
+
elif any(keyword in feature for keyword in ['lag', 'diff']):
|
| 627 |
+
categories['lag'].append(feature)
|
| 628 |
+
elif 'ratio' in feature or 'product' in feature:
|
| 629 |
+
categories['interaction'].append(feature)
|
| 630 |
+
elif 'dominant' in feature:
|
| 631 |
+
categories['spectral'].append(feature)
|
| 632 |
+
elif any(keyword in feature for keyword in ['trend', 'seasonal', 'residual']):
|
| 633 |
+
categories['decomposition'].append(feature)
|
| 634 |
+
|
| 635 |
+
# Remove empty categories
|
| 636 |
+
categories = {k: v for k, v in categories.items() if v}
|
| 637 |
+
|
| 638 |
+
return categories
|
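For orientation, a minimal usage sketch of the feature-engineering helpers above follows. Only the method names visible in this hunk are confirmed; the class name FeatureEngineer, its construction from a Config instance, and the 'discharge' target column are assumptions for illustration.

```python
# Hypothetical usage sketch; the FeatureEngineer class name, its constructor and the
# 'discharge' target column are assumptions, only the method names come from the diff.
import pandas as pd

from config.config import Config
from features.feature_engineer import FeatureEngineer

config = Config()
engineer = FeatureEngineer(config)

# A daily time series with a DatetimeIndex is assumed
df = pd.read_csv("temp_data.csv", index_col=0, parse_dates=True)

df = engineer.create_spectral_features(df, target_col="discharge")       # dominant-period features
df = engineer.create_decomposition_features(df, target_col="discharge")  # trend/seasonal/residual
print(engineer.get_feature_categories())                                  # created features by category
```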
missing_values/__init__.py
ADDED
|
File without changes
|
missing_values/missing_analyzer.py
ADDED
|
@@ -0,0 +1,700 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 3: MISSING VALUE ANALYSER
|
| 3 |
+
# ============================================
|
| 4 |
+
from typing import Dict, Tuple
|
| 5 |
+
import logging
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
from config.config import Config
|
| 8 |
+
from scipy.interpolate import interp1d
|
| 9 |
+
from statsmodels.tsa.seasonal import seasonal_decompose, STL
|
| 10 |
+
try:
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import numpy as np
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
print("✅ All imports working!")
|
| 15 |
+
except ImportError as e:
|
| 16 |
+
print(f"❌ Import error: {e}")
|
| 17 |
+
|
| 18 |
+
class MissingValueAnalyser:
|
| 19 |
+
"""Class for analysing and handling missing values"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, config: Config):
|
| 22 |
+
"""
|
| 23 |
+
Initialise missing value analyser
|
| 24 |
+
|
| 25 |
+
Parameters:
|
| 26 |
+
-----------
|
| 27 |
+
config : Config
|
| 28 |
+
Experiment configuration
|
| 29 |
+
"""
|
| 30 |
+
self.config = config
|
| 31 |
+
self.missing_info = {}
|
| 32 |
+
self.handling_methods = {}
|
| 33 |
+
self.imputers = {}
|
| 34 |
+
self.missing_patterns = {}
|
| 35 |
+
|
| 36 |
+
def analyse(
|
| 37 |
+
self,
|
| 38 |
+
data: pd.DataFrame,
|
| 39 |
+
detailed: bool = True
|
| 40 |
+
) -> Dict:
|
| 41 |
+
"""
|
| 42 |
+
Analyse missing values in data
|
| 43 |
+
|
| 44 |
+
Parameters:
|
| 45 |
+
-----------
|
| 46 |
+
data : pd.DataFrame
|
| 47 |
+
Input data
|
| 48 |
+
detailed : bool
|
| 49 |
+
Whether to perform detailed analysis
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
--------
|
| 53 |
+
Dict
|
| 54 |
+
Information about missing values
|
| 55 |
+
"""
|
| 56 |
+
logger.info("\n" + "="*80)
|
| 57 |
+
logger.info("MISSING VALUE ANALYSIS")
|
| 58 |
+
logger.info("="*80)
|
| 59 |
+
|
| 60 |
+
# Calculate missing values
|
| 61 |
+
missing_total = data.isnull().sum()
|
| 62 |
+
missing_percent = (missing_total / len(data)) * 100
|
| 63 |
+
|
| 64 |
+
missing_df = pd.DataFrame({
|
| 65 |
+
'missing_count': missing_total,
|
| 66 |
+
'missing_percent': missing_percent,
|
| 67 |
+
'dtype': data.dtypes.astype(str)
|
| 68 |
+
})
|
| 69 |
+
|
| 70 |
+
# Detailed analysis
|
| 71 |
+
if detailed:
|
| 72 |
+
self._detailed_missing_analysis(data, missing_df)
|
| 73 |
+
|
| 74 |
+
# Save information
|
| 75 |
+
self.missing_info = {
|
| 76 |
+
'summary': {
|
| 77 |
+
col: {
|
| 78 |
+
'missing_count': int(missing_df.loc[col, 'missing_count']),
|
| 79 |
+
'missing_percent': float(missing_df.loc[col, 'missing_percent']),
|
| 80 |
+
'dtype': missing_df.loc[col, 'dtype']
|
| 81 |
+
}
|
| 82 |
+
for col in missing_df.index
|
| 83 |
+
},
|
| 84 |
+
'overall': {
|
| 85 |
+
'total_missing': int(missing_total.sum()),
|
| 86 |
+
'total_rows': int(len(data)),
|
| 87 |
+
'total_cells': int(data.size),
|
| 88 |
+
'overall_missing_percentage': float(missing_total.sum() / data.size * 100),
|
| 89 |
+
'rows_with_any_missing': int(data.isnull().any(axis=1).sum()),
|
| 90 |
+
'rows_all_missing': int(data.isnull().all(axis=1).sum()),
|
| 91 |
+
'columns_with_missing': missing_df[missing_df['missing_count'] > 0].index.tolist(),
|
| 92 |
+
'columns_all_missing': missing_df[missing_df['missing_count'] == len(data)].index.tolist()
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# Visualisation
|
| 97 |
+
if self.config.save_plots:
|
| 98 |
+
self._plot_missing_values(data, missing_df)
|
| 99 |
+
|
| 100 |
+
# Output results
|
| 101 |
+
self._log_missing_summary(missing_df)
|
| 102 |
+
|
| 103 |
+
return self.missing_info
|
| 104 |
+
|
| 105 |
+
def _detailed_missing_analysis(
|
| 106 |
+
self,
|
| 107 |
+
data: pd.DataFrame,
|
| 108 |
+
missing_df: pd.DataFrame
|
| 109 |
+
) -> None:
|
| 110 |
+
"""Detailed missing value analysis"""
|
| 111 |
+
# Analyse missing patterns
|
| 112 |
+
missing_matrix = data.isnull()
|
| 113 |
+
|
| 114 |
+
# Row missing patterns
|
| 115 |
+
row_patterns = missing_matrix.apply(lambda x: ''.join(x.astype(int).astype(str)), axis=1)
|
| 116 |
+
row_pattern_counts = row_patterns.value_counts().head(10)
|
| 117 |
+
|
| 118 |
+
# Column missing patterns
|
| 119 |
+
col_patterns = missing_matrix.apply(lambda x: ''.join(x.astype(int).astype(str)), axis=0)
|
| 120 |
+
col_pattern_counts = col_patterns.value_counts().head(10)
|
| 121 |
+
|
| 122 |
+
# Time-based missing patterns analysis
|
| 123 |
+
time_patterns = {}
|
| 124 |
+
if isinstance(data.index, pd.DatetimeIndex):
|
| 125 |
+
# Missing values by time
|
| 126 |
+
time_missing = data.isnull().resample('M').sum()
|
| 127 |
+
time_patterns['monthly_missing'] = time_missing.sum(axis=1).to_dict()
|
| 128 |
+
|
| 129 |
+
# Missing values by day of week
|
| 130 |
+
data_with_dow = data.copy()
|
| 131 |
+
data_with_dow['dayofweek'] = data.index.dayofweek
|
| 132 |
+
dow_missing = data_with_dow.groupby('dayofweek').apply(lambda x: x.isnull().sum().sum())
|
| 133 |
+
time_patterns['dayofweek_missing'] = dow_missing.to_dict()
|
| 134 |
+
|
| 135 |
+
self.missing_patterns = {
|
| 136 |
+
'row_patterns': row_pattern_counts.to_dict(),
|
| 137 |
+
'col_patterns': col_pattern_counts.to_dict(),
|
| 138 |
+
'time_patterns': time_patterns,
|
| 139 |
+
'missing_correlation': missing_matrix.corr().to_dict() # Missing value correlation
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
logger.debug(f"Found {len(row_pattern_counts)} unique row missing patterns")
|
| 143 |
+
logger.debug(f"Found {len(col_pattern_counts)} unique column missing patterns")
|
| 144 |
+
|
| 145 |
+
def _plot_missing_values(
|
| 146 |
+
self,
|
| 147 |
+
data: pd.DataFrame,
|
| 148 |
+
missing_df: pd.DataFrame
|
| 149 |
+
) -> None:
|
| 150 |
+
"""Visualise missing values"""
|
| 151 |
+
fig, axes = plt.subplots(3, 2, figsize=(16, 12))
|
| 152 |
+
|
| 153 |
+
# 1. Missing percentage histogram
|
| 154 |
+
axes[0, 0].barh(
|
| 155 |
+
missing_df.index,
|
| 156 |
+
missing_df['missing_percent']
|
| 157 |
+
)
|
| 158 |
+
axes[0, 0].axvline(self.config.missing_threshold, color='red', linestyle='--')
|
| 159 |
+
axes[0, 0].set_title('Missing Percentage by Column')
|
| 160 |
+
axes[0, 0].set_xlabel('Missing Percentage (%)')
|
| 161 |
+
axes[0, 0].set_ylabel('Columns')
|
| 162 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 163 |
+
|
| 164 |
+
# 2. Missing values heatmap
|
| 165 |
+
missing_matrix = data.isnull()
|
| 166 |
+
axes[0, 1].imshow(
|
| 167 |
+
missing_matrix.T if len(data) <= 1000 else missing_matrix.iloc[:1000].T,
|
| 168 |
+
aspect='auto',
|
| 169 |
+
cmap='binary',
|
| 170 |
+
interpolation='none'
|
| 171 |
+
)
|
| 172 |
+
axes[0, 1].set_title('Missing Values Matrix')
|
| 173 |
+
axes[0, 1].set_xlabel('Observation Index')
|
| 174 |
+
axes[0, 1].set_ylabel('Variables')
|
| 175 |
+
axes[0, 1].set_yticks(range(len(data.columns)))
|
| 176 |
+
axes[0, 1].set_yticklabels(data.columns, fontsize=8)
|
| 177 |
+
|
| 178 |
+
# 3. Missing values over time (if time series)
|
| 179 |
+
if isinstance(data.index, pd.DatetimeIndex):
|
| 180 |
+
time_missing = data.isnull().resample('M').sum()
|
| 181 |
+
|
| 182 |
+
axes[1, 0].plot(time_missing.sum(axis=1))
|
| 183 |
+
axes[1, 0].set_title('Missing Values by Month')
|
| 184 |
+
axes[1, 0].set_xlabel('Date')
|
| 185 |
+
axes[1, 0].set_ylabel('Number of Missing Values')
|
| 186 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 187 |
+
|
| 188 |
+
# 4. Missing values by day of week
|
| 189 |
+
data_with_dow = data.copy()
|
| 190 |
+
data_with_dow['dayofweek'] = data.index.dayofweek
|
| 191 |
+
dow_missing = data_with_dow.groupby('dayofweek').apply(lambda x: x.isnull().sum().sum())
|
| 192 |
+
dow_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
| 193 |
+
|
| 194 |
+
axes[1, 1].bar(range(7), dow_missing)
|
| 195 |
+
axes[1, 1].set_title('Missing Values by Day of Week')
|
| 196 |
+
axes[1, 1].set_xlabel('Day of Week')
|
| 197 |
+
axes[1, 1].set_ylabel('Number of Missing Values')
|
| 198 |
+
axes[1, 1].set_xticks(range(7))
|
| 199 |
+
axes[1, 1].set_xticklabels(dow_names)
|
| 200 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 201 |
+
|
| 202 |
+
# 5. Missing value correlation
|
| 203 |
+
missing_corr = data.isnull().corr()
|
| 204 |
+
im = axes[2, 0].imshow(
|
| 205 |
+
missing_corr,
|
| 206 |
+
cmap='coolwarm',
|
| 207 |
+
vmin=-1,
|
| 208 |
+
vmax=1,
|
| 209 |
+
aspect='auto'
|
| 210 |
+
)
|
| 211 |
+
axes[2, 0].set_title('Missing Value Correlation Between Variables')
|
| 212 |
+
axes[2, 0].set_xlabel('Variables')
|
| 213 |
+
axes[2, 0].set_ylabel('Variables')
|
| 214 |
+
plt.colorbar(im, ax=axes[2, 0])
|
| 215 |
+
|
| 216 |
+
# 6. Cumulative missing sum
|
| 217 |
+
cumulative_missing = data.isnull().cumsum()
|
| 218 |
+
for col in data.columns[:5]: # First 5 columns
|
| 219 |
+
if data[col].isnull().any():
|
| 220 |
+
axes[2, 1].plot(
|
| 221 |
+
cumulative_missing.index,
|
| 222 |
+
cumulative_missing[col],
|
| 223 |
+
label=col[:20]
|
| 224 |
+
)
|
| 225 |
+
axes[2, 1].set_title('Cumulative Missing Values')
|
| 226 |
+
axes[2, 1].set_xlabel('Time/Index')
|
| 227 |
+
axes[2, 1].set_ylabel('Cumulative Missing')
|
| 228 |
+
axes[2, 1].legend(fontsize=8)
|
| 229 |
+
axes[2, 1].grid(True, alpha=0.3)
|
| 230 |
+
|
| 231 |
+
plt.tight_layout()
|
| 232 |
+
plt.savefig(
|
| 233 |
+
f'{self.config.results_dir}/plots/missing_values_analysis.png',
|
| 234 |
+
dpi=300,
|
| 235 |
+
bbox_inches='tight'
|
| 236 |
+
)
|
| 237 |
+
plt.show()
|
| 238 |
+
|
| 239 |
+
def _log_missing_summary(self, missing_df: pd.DataFrame) -> None:
|
| 240 |
+
"""Log missing value summary"""
|
| 241 |
+
missing_columns = missing_df[missing_df['missing_count'] > 0]
|
| 242 |
+
|
| 243 |
+
if len(missing_columns) > 0:
|
| 244 |
+
logger.info("MISSING VALUES FOUND:")
|
| 245 |
+
logger.info("-" * 50)
|
| 246 |
+
logger.info(f"Total missing values: {self.missing_info['overall']['total_missing']}")
|
| 247 |
+
logger.info(f"Overall missing percentage: {self.missing_info['overall']['overall_missing_percentage']:.2f}%")
|
| 248 |
+
logger.info(f"Rows with missing values: {self.missing_info['overall']['rows_with_any_missing']}")
|
| 249 |
+
logger.info(f"Columns with missing values: {len(self.missing_info['overall']['columns_with_missing'])}")
|
| 250 |
+
|
| 251 |
+
logger.info("\nTop-10 columns by missing values:")
|
| 252 |
+
top_missing = missing_df.nlargest(10, 'missing_percent')
|
| 253 |
+
for idx, (col, row) in enumerate(top_missing.iterrows(), 1):
|
| 254 |
+
logger.info(f" {idx:2d}. {col}: {int(row['missing_count'])} missing ({row['missing_percent']:.2f}%)")
|
| 255 |
+
else:
|
| 256 |
+
logger.info("✓ No missing values found")
|
| 257 |
+
|
| 258 |
+
def handle(
|
| 259 |
+
self,
|
| 260 |
+
data: pd.DataFrame,
|
| 261 |
+
method: str = 'interpolate',
|
| 262 |
+
strategy: str = 'columnwise',
|
| 263 |
+
**kwargs
|
| 264 |
+
) -> pd.DataFrame:
|
| 265 |
+
"""
|
| 266 |
+
Handle missing values
|
| 267 |
+
|
| 268 |
+
Parameters:
|
| 269 |
+
-----------
|
| 270 |
+
data : pd.DataFrame
|
| 271 |
+
Input data
|
| 272 |
+
method : str
|
| 273 |
+
Handling method: 'interpolate', 'ffill', 'bfill', 'mean', 'median', 'mode', 'knn', 'regression'
|
| 274 |
+
strategy : str
|
| 275 |
+
Strategy: 'columnwise', 'rowwise', 'global'
|
| 276 |
+
**kwargs : dict
|
| 277 |
+
Additional parameters for method
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
--------
|
| 281 |
+
pd.DataFrame
|
| 282 |
+
Data with handled missing values
|
| 283 |
+
"""
|
| 284 |
+
logger.info("\n" + "="*80)
|
| 285 |
+
logger.info("HANDLING MISSING VALUES")
|
| 286 |
+
logger.info("="*80)
|
| 287 |
+
|
| 288 |
+
data_processed = data.copy()
|
| 289 |
+
methods_applied = {}
|
| 290 |
+
|
| 291 |
+
# Determine columns to process
|
| 292 |
+
if strategy == 'columnwise':
|
| 293 |
+
columns_to_process = data_processed.columns
|
| 294 |
+
elif strategy == 'rowwise':
|
| 295 |
+
# Row-wise handling (for time series)
|
| 296 |
+
data_processed = self._handle_rowwise(data_processed, method, **kwargs)
|
| 297 |
+
return data_processed
|
| 298 |
+
else:
|
| 299 |
+
columns_to_process = data_processed.select_dtypes(include=[np.number]).columns
|
| 300 |
+
|
| 301 |
+
# Process each column
|
| 302 |
+
for col in columns_to_process:
|
| 303 |
+
missing_before = data_processed[col].isnull().sum()
|
| 304 |
+
|
| 305 |
+
if missing_before > 0:
|
| 306 |
+
# Check if missing percentage exceeds threshold
|
| 307 |
+
missing_percent = (missing_before / len(data_processed)) * 100
|
| 308 |
+
|
| 309 |
+
if missing_percent > self.config.missing_threshold:
|
| 310 |
+
logger.warning(f" {col}: {missing_before} missing ({missing_percent:.1f}%) > threshold {self.config.missing_threshold}%")
|
| 311 |
+
|
| 312 |
+
if kwargs.get('drop_high_missing', False):
|
| 313 |
+
data_processed = data_processed.drop(columns=[col])
|
| 314 |
+
method_used = f"dropped (>{self.config.missing_threshold}% missing)"
|
| 315 |
+
missing_after = 0
|
| 316 |
+
else:
|
| 317 |
+
# Use selected method
|
| 318 |
+
data_processed[col], method_used = self._apply_imputation_method(
|
| 319 |
+
data_processed[col], method, **kwargs
|
| 320 |
+
)
|
| 321 |
+
missing_after = data_processed[col].isnull().sum()
|
| 322 |
+
else:
|
| 323 |
+
# Use selected method
|
| 324 |
+
data_processed[col], method_used = self._apply_imputation_method(
|
| 325 |
+
data_processed[col], method, **kwargs
|
| 326 |
+
)
|
| 327 |
+
missing_after = data_processed[col].isnull().sum()
|
| 328 |
+
|
| 329 |
+
methods_applied[col] = {
|
| 330 |
+
'method': method_used,
|
| 331 |
+
'missing_before': int(missing_before),
|
| 332 |
+
'missing_after': int(missing_after),
|
| 333 |
+
'missing_percent_before': float((missing_before / len(data_processed)) * 100)
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
if missing_before > 0:
|
| 337 |
+
logger.info(f" {col}: {missing_before} → {missing_after} missing ({method_used})")
|
| 338 |
+
|
| 339 |
+
self.handling_methods = methods_applied
|
| 340 |
+
|
| 341 |
+
# Check that all missing values are handled
|
| 342 |
+
remaining_missing = data_processed.isnull().sum().sum()
|
| 343 |
+
if remaining_missing == 0:
|
| 344 |
+
logger.info("✓ All missing values successfully handled")
|
| 345 |
+
else:
|
| 346 |
+
logger.warning(f"⚠ {remaining_missing} missing values remain")
|
| 347 |
+
# Additional handling of remaining missing values
|
| 348 |
+
data_processed = data_processed.ffill().bfill()
|
| 349 |
+
remaining_after = data_processed.isnull().sum().sum()
|
| 350 |
+
if remaining_after == 0:
|
| 351 |
+
logger.info("✓ Remaining missing values handled with ffill/bfill combination")
|
| 352 |
+
|
| 353 |
+
return data_processed
|
| 354 |
+
|
| 355 |
+
def _apply_imputation_method(
|
| 356 |
+
self,
|
| 357 |
+
series: pd.Series,
|
| 358 |
+
method: str,
|
| 359 |
+
**kwargs
|
| 360 |
+
) -> Tuple[pd.Series, str]:
|
| 361 |
+
"""
|
| 362 |
+
Apply imputation method to individual series
|
| 363 |
+
|
| 364 |
+
Parameters:
|
| 365 |
+
-----------
|
| 366 |
+
series : pd.Series
|
| 367 |
+
Input series
|
| 368 |
+
method : str
|
| 369 |
+
Imputation method
|
| 370 |
+
**kwargs : dict
|
| 371 |
+
Additional parameters
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
--------
|
| 375 |
+
Tuple[pd.Series, str]
|
| 376 |
+
Processed series and method description
|
| 377 |
+
"""
|
| 378 |
+
if method == 'interpolate':
|
| 379 |
+
# Interpolation for time series
|
| 380 |
+
if isinstance(series.index, pd.DatetimeIndex):
|
| 381 |
+
method_name = f"{kwargs.get('interpolation_method', 'linear')} interpolation"
|
| 382 |
+
series_filled = series.interpolate(
|
| 383 |
+
method=kwargs.get('interpolation_method', 'linear'),
|
| 384 |
+
limit_direction=kwargs.get('limit_direction', 'both'),
|
| 385 |
+
limit=kwargs.get('limit', None)
|
| 386 |
+
)
|
| 387 |
+
else:
|
| 388 |
+
method_name = 'linear interpolation'
|
| 389 |
+
series_filled = series.interpolate(method='linear')
|
| 390 |
+
|
| 391 |
+
elif method == 'time_weighted':
|
| 392 |
+
# Time-weighted interpolation
|
| 393 |
+
method_name = 'time-weighted interpolation'
|
| 394 |
+
series_filled = self._time_weighted_interpolation(series)
|
| 395 |
+
|
| 396 |
+
elif method == 'seasonal':
|
| 397 |
+
# Seasonal interpolation
|
| 398 |
+
method_name = 'seasonal interpolation'
|
| 399 |
+
series_filled = self._seasonal_interpolation(series, **kwargs)
|
| 400 |
+
|
| 401 |
+
elif method == 'ffill':
|
| 402 |
+
# Forward fill
|
| 403 |
+
method_name = 'forward fill'
|
| 404 |
+
series_filled = series.ffill(limit=kwargs.get('limit', None))
|
| 405 |
+
|
| 406 |
+
elif method == 'bfill':
|
| 407 |
+
# Backward fill
|
| 408 |
+
method_name = 'backward fill'
|
| 409 |
+
series_filled = series.bfill(limit=kwargs.get('limit', None))
|
| 410 |
+
|
| 411 |
+
elif method == 'mean':
|
| 412 |
+
# Mean imputation
|
| 413 |
+
method_name = 'mean imputation'
|
| 414 |
+
series_filled = series.fillna(series.mean())
|
| 415 |
+
|
| 416 |
+
elif method == 'median':
|
| 417 |
+
# Median imputation
|
| 418 |
+
method_name = 'median imputation'
|
| 419 |
+
series_filled = series.fillna(series.median())
|
| 420 |
+
|
| 421 |
+
elif method == 'mode':
|
| 422 |
+
# Mode imputation
|
| 423 |
+
method_name = 'mode imputation'
|
| 424 |
+
mode_value = series.mode()
|
| 425 |
+
if not mode_value.empty:
|
| 426 |
+
series_filled = series.fillna(mode_value.iloc[0])
|
| 427 |
+
else:
|
| 428 |
+
series_filled = series.fillna(series.median())
|
| 429 |
+
|
| 430 |
+
elif method == 'knn':
|
| 431 |
+
# KNN imputation
|
| 432 |
+
method_name = f"KNN imputation (k={kwargs.get('k', 5)})"
|
| 433 |
+
# Simplified version using nearest neighbour mean
|
| 434 |
+
series_filled = self._knn_imputation(series, k=kwargs.get('k', 5))
|
| 435 |
+
|
| 436 |
+
elif method == 'regression':
|
| 437 |
+
# Regression imputation
|
| 438 |
+
method_name = 'regression imputation'
|
| 439 |
+
series_filled = self._regression_imputation(series, **kwargs)
|
| 440 |
+
|
| 441 |
+
elif method == 'spline':
|
| 442 |
+
# Spline interpolation
|
| 443 |
+
method_name = 'spline interpolation'
|
| 444 |
+
series_filled = series.interpolate(method='spline', order=kwargs.get('order', 3))
|
| 445 |
+
|
| 446 |
+
elif method == 'stl':
|
| 447 |
+
# STL decomposition + interpolation
|
| 448 |
+
method_name = 'STL-based imputation'
|
| 449 |
+
series_filled = self._stl_imputation(series, **kwargs)
|
| 450 |
+
|
| 451 |
+
else:
|
| 452 |
+
raise ValueError(f"Unknown method: {method}")
|
| 453 |
+
|
| 454 |
+
# If missing values remain, fill with ffill/bfill
|
| 455 |
+
if series_filled.isnull().any():
|
| 456 |
+
series_filled = series_filled.ffill().bfill()
|
| 457 |
+
method_name += " + ffill/bfill"
|
| 458 |
+
|
| 459 |
+
return series_filled, method_name
|
| 460 |
+
|
| 461 |
+
def _time_weighted_interpolation(self, series: pd.Series) -> pd.Series:
|
| 462 |
+
"""Time-weighted interpolation"""
|
| 463 |
+
if not isinstance(series.index, pd.DatetimeIndex):
|
| 464 |
+
return series.interpolate()
|
| 465 |
+
|
| 466 |
+
# Create timestamps
|
| 467 |
+
time_numeric = pd.Series(range(len(series)), index=series.index)
|
| 468 |
+
|
| 469 |
+
# Interpolate timestamps for missing values
|
| 470 |
+
time_interpolated = time_numeric.interpolate()
|
| 471 |
+
|
| 472 |
+
# Interpolate values based on timestamps
|
| 473 |
+
valid_mask = series.notna()
|
| 474 |
+
if valid_mask.sum() < 2:
|
| 475 |
+
return series.ffill().bfill()
|
| 476 |
+
|
| 477 |
+
# Use linear interpolation
|
| 478 |
+
valid_times = time_numeric[valid_mask]
|
| 479 |
+
valid_values = series[valid_mask]
|
| 480 |
+
|
| 481 |
+
# Interpolation
|
| 482 |
+
interp_func = interp1d(
|
| 483 |
+
valid_times,
|
| 484 |
+
valid_values,
|
| 485 |
+
kind='linear',
|
| 486 |
+
bounds_error=False,
|
| 487 |
+
fill_value='extrapolate'
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
series_filled = series.copy()
|
| 491 |
+
missing_mask = series.isna()
|
| 492 |
+
series_filled[missing_mask] = interp_func(time_interpolated[missing_mask])
|
| 493 |
+
|
| 494 |
+
return series_filled
|
| 495 |
+
|
| 496 |
+
def _seasonal_interpolation(
|
| 497 |
+
self,
|
| 498 |
+
series: pd.Series,
|
| 499 |
+
**kwargs
|
| 500 |
+
) -> pd.Series:
|
| 501 |
+
"""Seasonal interpolation"""
|
| 502 |
+
if not isinstance(series.index, pd.DatetimeIndex):
|
| 503 |
+
return series.interpolate()
|
| 504 |
+
|
| 505 |
+
period = kwargs.get('period', self.config.seasonal_period)
|
| 506 |
+
|
| 507 |
+
# Create series copy
|
| 508 |
+
series_filled = series.copy()
|
| 509 |
+
|
| 510 |
+
# Interpolation considering seasonality
|
| 511 |
+
for i in range(len(series)):
|
| 512 |
+
if pd.isna(series.iloc[i]):
|
| 513 |
+
# Find values at same seasonal position
|
| 514 |
+
seasonal_indices = []
|
| 515 |
+
for offset in range(1, 10): # Look in previous/next cycles
|
| 516 |
+
idx_back = i - offset * period
|
| 517 |
+
idx_forward = i + offset * period
|
| 518 |
+
|
| 519 |
+
if idx_back >= 0 and not pd.isna(series.iloc[idx_back]):
|
| 520 |
+
seasonal_indices.append(idx_back)
|
| 521 |
+
|
| 522 |
+
if idx_forward < len(series) and not pd.isna(series.iloc[idx_forward]):
|
| 523 |
+
seasonal_indices.append(idx_forward)
|
| 524 |
+
|
| 525 |
+
if seasonal_indices:
|
| 526 |
+
# Take mean value from seasonal positions
|
| 527 |
+
seasonal_values = series.iloc[seasonal_indices]
|
| 528 |
+
series_filled.iloc[i] = seasonal_values.mean()
|
| 529 |
+
|
| 530 |
+
# Fill remaining missing values with regular interpolation
|
| 531 |
+
series_filled = series_filled.interpolate()
|
| 532 |
+
|
| 533 |
+
return series_filled
|
| 534 |
+
|
| 535 |
+
def _knn_imputation(
|
| 536 |
+
self,
|
| 537 |
+
series: pd.Series,
|
| 538 |
+
k: int = 5
|
| 539 |
+
) -> pd.Series:
|
| 540 |
+
"""KNN imputation for time series"""
|
| 541 |
+
# Simplified KNN for time series
|
| 542 |
+
series_filled = series.copy()
|
| 543 |
+
|
| 544 |
+
for i in range(len(series)):
|
| 545 |
+
if pd.isna(series.iloc[i]):
|
| 546 |
+
# Find nearest k non-missing values
|
| 547 |
+
distances = []
|
| 548 |
+
values = []
|
| 549 |
+
|
| 550 |
+
for j in range(max(0, i - k * 10), min(len(series), i + k * 10)):
|
| 551 |
+
if j != i and not pd.isna(series.iloc[j]):
|
| 552 |
+
distance = abs(i - j)
|
| 553 |
+
distances.append(distance)
|
| 554 |
+
values.append(series.iloc[j])
|
| 555 |
+
|
| 556 |
+
if len(values) >= k:
|
| 557 |
+
break
|
| 558 |
+
|
| 559 |
+
if values:
|
| 560 |
+
# Distance-weighted average
|
| 561 |
+
weights = [1 / (d + 1) for d in distances]
|
| 562 |
+
weighted_avg = np.average(values, weights=weights)
|
| 563 |
+
series_filled.iloc[i] = weighted_avg
|
| 564 |
+
|
| 565 |
+
return series_filled
|
| 566 |
+
|
| 567 |
+
def _regression_imputation(
|
| 568 |
+
self,
|
| 569 |
+
series: pd.Series,
|
| 570 |
+
**kwargs
|
| 571 |
+
) -> pd.Series:
|
| 572 |
+
"""Regression imputation based on neighbouring values"""
|
| 573 |
+
# Simplified regression for time series
|
| 574 |
+
series_filled = series.copy()
|
| 575 |
+
|
| 576 |
+
if series.notna().sum() < 3:
|
| 577 |
+
return series.ffill().bfill()
|
| 578 |
+
|
| 579 |
+
# Use polynomial regression
|
| 580 |
+
x = np.arange(len(series))
|
| 581 |
+
y = series.values
|
| 582 |
+
|
| 583 |
+
# Valid values mask
|
| 584 |
+
valid_mask = ~np.isnan(y)
|
| 585 |
+
|
| 586 |
+
if valid_mask.sum() < 2:
|
| 587 |
+
return series.ffill().bfill()
|
| 588 |
+
|
| 589 |
+
# Polynomial regression degree 2
|
| 590 |
+
coeffs = np.polyfit(x[valid_mask], y[valid_mask], 2)
|
| 591 |
+
poly_func = np.poly1d(coeffs)
|
| 592 |
+
|
| 593 |
+
# Fill missing values
|
| 594 |
+
missing_mask = np.isnan(y)
|
| 595 |
+
series_filled.iloc[missing_mask] = poly_func(x[missing_mask])
|
| 596 |
+
|
| 597 |
+
return series_filled
|
| 598 |
+
|
| 599 |
+
def _stl_imputation(
|
| 600 |
+
self,
|
| 601 |
+
series: pd.Series,
|
| 602 |
+
**kwargs
|
| 603 |
+
) -> pd.Series:
|
| 604 |
+
"""STL decomposition-based imputation"""
|
| 605 |
+
try:
|
| 606 |
+
if not isinstance(series.index, pd.DatetimeIndex):
|
| 607 |
+
return series.interpolate()
|
| 608 |
+
|
| 609 |
+
# STL decomposition
|
| 610 |
+
stl = STL(
|
| 611 |
+
series.ffill().bfill(), # Fill missing for STL
|
| 612 |
+
period=kwargs.get('period', self.config.seasonal_period),
|
| 613 |
+
robust=True
|
| 614 |
+
)
|
| 615 |
+
result = stl.fit()
|
| 616 |
+
|
| 617 |
+
# Reconstruct series without noise
|
| 618 |
+
reconstructed = result.trend + result.seasonal
|
| 619 |
+
|
| 620 |
+
# Replace missing values with reconstructed values
|
| 621 |
+
series_filled = series.copy()
|
| 622 |
+
missing_mask = series.isna()
|
| 623 |
+
series_filled[missing_mask] = reconstructed[missing_mask]
|
| 624 |
+
|
| 625 |
+
return series_filled
|
| 626 |
+
|
| 627 |
+
except Exception as e:
|
| 628 |
+
logger.warning(f"STL imputation failed: {e}, using interpolation")
|
| 629 |
+
return series.interpolate()
|
| 630 |
+
|
| 631 |
+
def _handle_rowwise(
|
| 632 |
+
self,
|
| 633 |
+
data: pd.DataFrame,
|
| 634 |
+
method: str,
|
| 635 |
+
**kwargs
|
| 636 |
+
) -> pd.DataFrame:
|
| 637 |
+
"""Row-wise missing value handling"""
|
| 638 |
+
data_processed = data.copy()
|
| 639 |
+
|
| 640 |
+
# Remove rows with high missing counts
|
| 641 |
+
if kwargs.get('drop_rows_threshold', 0) > 0:
|
| 642 |
+
threshold = kwargs['drop_rows_threshold']
|
| 643 |
+
rows_before = len(data_processed)
|
| 644 |
+
missing_per_row = data_processed.isnull().sum(axis=1) / data_processed.shape[1] * 100
|
| 645 |
+
rows_to_drop = missing_per_row[missing_per_row > threshold].index
|
| 646 |
+
data_processed = data_processed.drop(rows_to_drop)
|
| 647 |
+
rows_after = len(data_processed)
|
| 648 |
+
logger.info(f"Rows removed: {rows_before - rows_after} (missing > {threshold}%)")
|
| 649 |
+
|
| 650 |
+
# Row-wise imputation
|
| 651 |
+
if method == 'row_mean':
|
| 652 |
+
data_processed = data_processed.T.fillna(data_processed.mean(axis=1)).T
|
| 653 |
+
elif method == 'row_median':
|
| 654 |
+
data_processed = data_processed.T.fillna(data_processed.median(axis=1)).T
|
| 655 |
+
elif method == 'row_ffill':
|
| 656 |
+
data_processed = data_processed.ffill(axis=1).bfill(axis=1)
|
| 657 |
+
|
| 658 |
+
return data_processed
|
| 659 |
+
|
| 660 |
+
def create_validation_rules(self) -> Dict:
|
| 661 |
+
"""Create validation rules based on missing value analysis"""
|
| 662 |
+
rules = {}
|
| 663 |
+
|
| 664 |
+
for col, info in self.missing_info['summary'].items():
|
| 665 |
+
missing_percent = info['missing_percent']
|
| 666 |
+
|
| 667 |
+
if missing_percent > 50:
|
| 668 |
+
rules[col] = {
|
| 669 |
+
'action': 'drop_column',
|
| 670 |
+
'reason': f'Missing > 50%: {missing_percent:.1f}%'
|
| 671 |
+
}
|
| 672 |
+
elif missing_percent > 20:
|
| 673 |
+
rules[col] = {
|
| 674 |
+
'action': 'advanced_imputation',
|
| 675 |
+
'reason': f'High missing: {missing_percent:.1f}%',
|
| 676 |
+
'recommended_method': 'knn'
|
| 677 |
+
}
|
| 678 |
+
elif missing_percent > 5:
|
| 679 |
+
rules[col] = {
|
| 680 |
+
'action': 'standard_imputation',
|
| 681 |
+
'reason': f'Moderate missing: {missing_percent:.1f}%',
|
| 682 |
+
'recommended_method': 'interpolate'
|
| 683 |
+
}
|
| 684 |
+
elif missing_percent > 0:
|
| 685 |
+
rules[col] = {
|
| 686 |
+
'action': 'simple_imputation',
|
| 687 |
+
'reason': f'Low missing: {missing_percent:.1f}%',
|
| 688 |
+
'recommended_method': 'ffill'
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
return rules
|
| 692 |
+
|
| 693 |
+
def get_report(self) -> Dict:
|
| 694 |
+
"""Get missing values report"""
|
| 695 |
+
return {
|
| 696 |
+
'missing_info': self.missing_info,
|
| 697 |
+
'handling_methods': self.handling_methods,
|
| 698 |
+
'missing_patterns': self.missing_patterns,
|
| 699 |
+
'validation_rules': self.create_validation_rules()
|
| 700 |
+
}
|
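For completeness, a minimal usage sketch of the MissingValueAnalyser defined above follows; the CSV path and the assumption that Config() defaults are usable as-is are illustrative only.

```python
# Sketch of the MissingValueAnalyser API as defined above; the file path and Config
# defaults are assumptions.
import pandas as pd

from config.config import Config
from missing_values.missing_analyzer import MissingValueAnalyser

config = Config()
analyser = MissingValueAnalyser(config)

df = pd.read_csv("temp_data.csv", index_col=0, parse_dates=True)

info = analyser.analyse(df, detailed=True)            # summary, patterns, optional plots
df_clean = analyser.handle(df, method="interpolate")  # column-wise imputation, ffill/bfill fallback
report = analyser.get_report()                        # applied methods plus validation rules
```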
outliers/__init__.py
ADDED
|
File without changes
|
outliers/outlier_analyzer.py
ADDED
|
@@ -0,0 +1,857 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 4: OUTLIER ANALYSER
|
| 3 |
+
# ============================================
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
+
import logging
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
from config.config import Config
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from sklearn.neighbors import LocalOutlierFactor
|
| 12 |
+
from sklearn.covariance import EllipticEnvelope
|
| 13 |
+
from scipy import stats
|
| 14 |
+
|
| 15 |
+
class OutlierAnalyser:
|
| 16 |
+
"""Class for analysing and handling outliers"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, config: Config):
|
| 19 |
+
"""
|
| 20 |
+
Initialise outlier analyser
|
| 21 |
+
|
| 22 |
+
Parameters:
|
| 23 |
+
-----------
|
| 24 |
+
config : Config
|
| 25 |
+
Experiment configuration
|
| 26 |
+
"""
|
| 27 |
+
self.config = config
|
| 28 |
+
self.outlier_info = {}
|
| 29 |
+
self.handling_methods = {}
|
| 30 |
+
self.detection_methods = {}
|
| 31 |
+
self.outlier_models = {}
|
| 32 |
+
|
| 33 |
+
def analyse(
|
| 34 |
+
self,
|
| 35 |
+
data: pd.DataFrame,
|
| 36 |
+
method: str = None,
|
| 37 |
+
columns: List[str] = None,
|
| 38 |
+
**kwargs
|
| 39 |
+
) -> Dict:
|
| 40 |
+
"""
|
| 41 |
+
Analyse outliers in data
|
| 42 |
+
|
| 43 |
+
Parameters:
|
| 44 |
+
-----------
|
| 45 |
+
data : pd.DataFrame
|
| 46 |
+
Input data
|
| 47 |
+
method : str, optional
|
| 48 |
+
Detection method. If None, uses configuration value.
|
| 49 |
+
columns : List[str], optional
|
| 50 |
+
List of columns to analyse. If None, uses all numeric columns.
|
| 51 |
+
**kwargs : dict
|
| 52 |
+
Additional parameters for method
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
--------
|
| 56 |
+
Dict
|
| 57 |
+
Information about outliers
|
| 58 |
+
"""
|
| 59 |
+
logger.info("\n" + "="*80)
|
| 60 |
+
logger.info("OUTLIER ANALYSIS")
|
| 61 |
+
logger.info("="*80)
|
| 62 |
+
|
| 63 |
+
method = method or self.config.outlier_method
|
| 64 |
+
if columns is None:
|
| 65 |
+
columns = data.select_dtypes(include=[np.number]).columns
|
| 66 |
+
|
| 67 |
+
outliers_info = {}
|
| 68 |
+
|
| 69 |
+
# Apply various detection methods
|
| 70 |
+
detection_results = {}
|
| 71 |
+
|
| 72 |
+
# 1. Statistical methods
|
| 73 |
+
if method in ['iqr', 'zscore', 'sigma', 'all']:
|
| 74 |
+
detection_results.update(self._statistical_methods(data, columns, method, **kwargs))
|
| 75 |
+
|
| 76 |
+
# 2. ML methods
|
| 77 |
+
if method in ['lof', 'isolation_forest', 'elliptic_envelope', 'all']:
|
| 78 |
+
detection_results.update(self._ml_methods(data, columns, method, **kwargs))
|
| 79 |
+
|
| 80 |
+
# 3. Temporal methods
|
| 81 |
+
if isinstance(data.index, pd.DatetimeIndex):
|
| 82 |
+
detection_results.update(self._temporal_methods(data, columns, **kwargs))
|
| 83 |
+
|
| 84 |
+
# Aggregate results
|
| 85 |
+
for col in columns:
|
| 86 |
+
if col in detection_results:
|
| 87 |
+
# Combine results from different methods
|
| 88 |
+
combined_mask = self._combine_detection_methods(detection_results, col)
|
| 89 |
+
|
| 90 |
+
outliers_count = combined_mask.sum()
|
| 91 |
+
outliers_percent = (outliers_count / len(data)) * 100
|
| 92 |
+
|
| 93 |
+
# Detailed information
|
| 94 |
+
col_data = data[col].dropna()
|
| 95 |
+
stats = {
|
| 96 |
+
'mean': float(col_data.mean()),
|
| 97 |
+
'std': float(col_data.std()),
|
| 98 |
+
'median': float(col_data.median()),
|
| 99 |
+
'q1': float(col_data.quantile(0.25)),
|
| 100 |
+
'q3': float(col_data.quantile(0.75)),
|
| 101 |
+
'min': float(col_data.min()),
|
| 102 |
+
'max': float(col_data.max()),
|
| 103 |
+
'skewness': float(col_data.skew()),
|
| 104 |
+
'kurtosis': float(col_data.kurtosis())
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
outliers_info[col] = {
|
| 108 |
+
'method': method,
|
| 109 |
+
'statistics': stats,
|
| 110 |
+
'outliers_count': int(outliers_count),
|
| 111 |
+
'outliers_percent': float(outliers_percent),
|
| 112 |
+
'outlier_indices': data[combined_mask].index.tolist() if outliers_count > 0 else [],
|
| 113 |
+
'outlier_values': data.loc[combined_mask, col].tolist() if outliers_count > 0 else [],
|
| 114 |
+
'detection_methods': {
|
| 115 |
+
name: {
|
| 116 |
+
'count': int(mask.sum()),
|
| 117 |
+
'percent': float(mask.sum() / len(data) * 100)
|
| 118 |
+
}
|
| 119 |
+
for name, mask in detection_results[col].items()
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
logger.info(f"{col}: {outliers_count} outliers ({outliers_percent:.2f}%)")
|
| 124 |
+
|
| 125 |
+
self.outlier_info = outliers_info
|
| 126 |
+
self.detection_methods = detection_results
|
| 127 |
+
|
| 128 |
+
# Visualisation
|
| 129 |
+
if self.config.save_plots and len(columns) > 0:
|
| 130 |
+
self._plot_outlier_analysis(data, columns, outliers_info)
|
| 131 |
+
|
| 132 |
+
return outliers_info
|
| 133 |
+
|
| 134 |
+
def _statistical_methods(
|
| 135 |
+
self,
|
| 136 |
+
data: pd.DataFrame,
|
| 137 |
+
columns: List[str],
|
| 138 |
+
method: str,
|
| 139 |
+
**kwargs
|
| 140 |
+
) -> Dict:
|
| 141 |
+
"""Statistical outlier detection methods"""
|
| 142 |
+
results = {}
|
| 143 |
+
|
| 144 |
+
for col in columns:
|
| 145 |
+
col_results = {}
|
| 146 |
+
series = data[col].dropna()
|
| 147 |
+
|
| 148 |
+
if len(series) < 3:
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
# IQR method
|
| 152 |
+
if method in ['iqr', 'all']:
|
| 153 |
+
q1 = series.quantile(0.25)
|
| 154 |
+
q3 = series.quantile(0.75)
|
| 155 |
+
iqr = q3 - q1
|
| 156 |
+
lower_bound = q1 - self.config.outlier_alpha * iqr
|
| 157 |
+
upper_bound = q3 + self.config.outlier_alpha * iqr
|
| 158 |
+
|
| 159 |
+
iqr_mask = (data[col] < lower_bound) | (data[col] > upper_bound)
|
| 160 |
+
col_results['iqr'] = iqr_mask
|
| 161 |
+
|
| 162 |
+
# Z-score method
|
| 163 |
+
if method in ['zscore', 'sigma', 'all']:
|
| 164 |
+
z_threshold = kwargs.get('z_threshold', 3)
|
| 165 |
+
z_scores = np.abs((data[col] - series.mean()) / series.std())
|
| 166 |
+
z_mask = z_scores > z_threshold
|
| 167 |
+
col_results['zscore'] = z_mask
|
| 168 |
+
|
| 169 |
+
# Modified Z-score method
|
| 170 |
+
if method in ['zscore', 'all']:
|
| 171 |
+
median = series.median()
|
| 172 |
+
mad = np.median(np.abs(series - median))
|
| 173 |
+
if mad != 0:
|
| 174 |
+
modified_z_scores = 0.6745 * (data[col] - median) / mad
|
| 175 |
+
mz_mask = np.abs(modified_z_scores) > 3.5
|
| 176 |
+
col_results['modified_zscore'] = mz_mask
|
| 177 |
+
|
| 178 |
+
# Tukey's fences
|
| 179 |
+
if method in ['iqr', 'all']:
|
| 180 |
+
inner_lower = q1 - 1.5 * iqr
|
| 181 |
+
inner_upper = q3 + 1.5 * iqr
|
| 182 |
+
outer_lower = q1 - 3 * iqr
|
| 183 |
+
outer_upper = q3 + 3 * iqr
|
| 184 |
+
|
| 185 |
+
mild_mask = ((data[col] < inner_lower) | (data[col] > inner_upper)) & \
|
| 186 |
+
((data[col] >= outer_lower) & (data[col] <= outer_upper))
|
| 187 |
+
extreme_mask = (data[col] < outer_lower) | (data[col] > outer_upper)
|
| 188 |
+
|
| 189 |
+
col_results['tukey_mild'] = mild_mask
|
| 190 |
+
col_results['tukey_extreme'] = extreme_mask
|
| 191 |
+
|
| 192 |
+
results[col] = col_results
|
| 193 |
+
|
| 194 |
+
return results
|
| 195 |
+
|
| 196 |
+
def _ml_methods(
|
| 197 |
+
self,
|
| 198 |
+
data: pd.DataFrame,
|
| 199 |
+
columns: List[str],
|
| 200 |
+
method: str,
|
| 201 |
+
**kwargs
|
| 202 |
+
) -> Dict:
|
| 203 |
+
"""ML outlier detection methods"""
|
| 204 |
+
results = {}
|
| 205 |
+
|
| 206 |
+
numeric_data = data[columns].dropna()
|
| 207 |
+
|
| 208 |
+
if len(numeric_data) < 10:
|
| 209 |
+
return results
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
# Local Outlier Factor
|
| 213 |
+
if method in ['lof', 'all']:
|
| 214 |
+
lof = LocalOutlierFactor(
|
| 215 |
+
contamination=self.config.outlier_contamination,
|
| 216 |
+
n_neighbors=kwargs.get('n_neighbors', 20)
|
| 217 |
+
)
|
| 218 |
+
lof_labels = lof.fit_predict(numeric_data)
|
| 219 |
+
lof_mask = pd.Series(lof_labels == -1, index=numeric_data.index)
|
| 220 |
+
|
| 221 |
+
for col in columns:
|
| 222 |
+
if col in numeric_data.columns:
|
| 223 |
+
if col not in results:
|
| 224 |
+
results[col] = {}
|
| 225 |
+
results[col]['lof'] = lof_mask
|
| 226 |
+
|
| 227 |
+
# Elliptic Envelope
|
| 228 |
+
if method in ['elliptic_envelope', 'all']:
|
| 229 |
+
try:
|
| 230 |
+
envelope = EllipticEnvelope(
|
| 231 |
+
contamination=self.config.outlier_contamination,
|
| 232 |
+
random_state=42
|
| 233 |
+
)
|
| 234 |
+
envelope_labels = envelope.fit_predict(numeric_data)
|
| 235 |
+
envelope_mask = pd.Series(envelope_labels == -1, index=numeric_data.index)
|
| 236 |
+
|
| 237 |
+
for col in columns:
|
| 238 |
+
if col in numeric_data.columns:
|
| 239 |
+
if col not in results:
|
| 240 |
+
results[col] = {}
|
| 241 |
+
results[col]['elliptic_envelope'] = envelope_mask
|
| 242 |
+
except Exception as e:
|
| 243 |
+
logger.warning(f"Elliptic Envelope failed: {e}")
|
| 244 |
+
|
| 245 |
+
except Exception as e:
|
| 246 |
+
logger.warning(f"ML outlier detection methods failed: {e}")
|
| 247 |
+
|
| 248 |
+
return results
|
| 249 |
+
|
| 250 |
+
def _temporal_methods(
|
| 251 |
+
self,
|
| 252 |
+
data: pd.DataFrame,
|
| 253 |
+
columns: List[str],
|
| 254 |
+
**kwargs
|
| 255 |
+
) -> Dict:
|
| 256 |
+
"""Outlier detection methods for time series"""
|
| 257 |
+
results = {}
|
| 258 |
+
|
| 259 |
+
for col in columns:
|
| 260 |
+
col_results = {}
|
| 261 |
+
series = data[col].dropna()
|
| 262 |
+
|
| 263 |
+
if len(series) < 30:
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
# Rolling statistics method
|
| 267 |
+
window = kwargs.get('temporal_window', 30)
|
| 268 |
+
rolling_mean = series.rolling(window=window, center=True).mean()
|
| 269 |
+
rolling_std = series.rolling(window=window, center=True).std()
|
| 270 |
+
|
| 271 |
+
# Outliers relative to moving average
|
| 272 |
+
threshold = kwargs.get('temporal_threshold', 3)
|
| 273 |
+
temporal_mask = np.abs(series - rolling_mean) > (threshold * rolling_std)
|
| 274 |
+
col_results['temporal'] = temporal_mask
|
| 275 |
+
|
| 276 |
+
# Seasonal detrending + outlier detection
|
| 277 |
+
try:
|
| 278 |
+
# Simple seasonal detrending
|
| 279 |
+
if len(series) > 365:
|
| 280 |
+
seasonal_period = kwargs.get('seasonal_period', 365)
|
| 281 |
+
seasonal_mean = series.rolling(window=seasonal_period, center=True).mean()
|
| 282 |
+
detrended = series - seasonal_mean
|
| 283 |
+
|
| 284 |
+
# Outliers in detrended series
|
| 285 |
+
q1 = detrended.quantile(0.25)
|
| 286 |
+
q3 = detrended.quantile(0.75)
|
| 287 |
+
iqr = q3 - q1
|
| 288 |
+
seasonal_lower = q1 - 3 * iqr
|
| 289 |
+
seasonal_upper = q3 + 3 * iqr
|
| 290 |
+
|
| 291 |
+
seasonal_mask = (detrended < seasonal_lower) | (detrended > seasonal_upper)
|
| 292 |
+
col_results['seasonal'] = seasonal_mask
|
| 293 |
+
except Exception as e:
|
| 294 |
+
logger.debug(f"Seasonal outlier detection failed for {col}: {e}")
|
| 295 |
+
|
| 296 |
+
results[col] = col_results
|
| 297 |
+
|
| 298 |
+
return results
|
| 299 |
+
|
| 300 |
+
def _combine_detection_methods(
|
| 301 |
+
self,
|
| 302 |
+
detection_results: Dict,
|
| 303 |
+
column: str
|
| 304 |
+
) -> pd.Series:
|
| 305 |
+
"""Combine results from different detection methods"""
|
| 306 |
+
if column not in detection_results:
|
| 307 |
+
return pd.Series(False, index=pd.RangeIndex(0))
|
| 308 |
+
|
| 309 |
+
methods = detection_results[column]
|
| 310 |
+
combined_mask = None
|
| 311 |
+
|
| 312 |
+
for method_name, mask in methods.items():
|
| 313 |
+
if combined_mask is None:
|
| 314 |
+
combined_mask = mask.copy()
|
| 315 |
+
else:
|
| 316 |
+
# Combine via OR (outlier by any method)
|
| 317 |
+
combined_mask = combined_mask | mask
|
| 318 |
+
|
| 319 |
+
return combined_mask.fillna(False)
|
| 320 |
+
|
| 321 |
+
def _plot_outlier_analysis(
|
| 322 |
+
self,
|
| 323 |
+
data: pd.DataFrame,
|
| 324 |
+
columns: List[str],
|
| 325 |
+
outliers_info: Dict
|
| 326 |
+
) -> None:
|
| 327 |
+
"""Visualise outlier analysis"""
|
| 328 |
+
n_cols = min(len(columns), 4)
|
| 329 |
+
n_rows = (len(columns) + n_cols - 1) // n_cols
|
| 330 |
+
|
| 331 |
+
fig = plt.figure(figsize=(16, 4 * n_rows))
|
| 332 |
+
gs = fig.add_gridspec(n_rows, n_cols)
|
| 333 |
+
|
| 334 |
+
for idx, col in enumerate(columns):
|
| 335 |
+
if col not in outliers_info:
|
| 336 |
+
continue
|
| 337 |
+
|
| 338 |
+
row = idx // n_cols
|
| 339 |
+
col_idx = idx % n_cols
|
| 340 |
+
|
| 341 |
+
ax = fig.add_subplot(gs[row, col_idx])
|
| 342 |
+
|
| 343 |
+
# Data
|
| 344 |
+
series = data[col].dropna()
|
| 345 |
+
|
| 346 |
+
# 1. Box plot
|
| 347 |
+
bp = ax.boxplot(
|
| 348 |
+
series.values,
|
| 349 |
+
vert=True,
|
| 350 |
+
patch_artist=True,
|
| 351 |
+
widths=0.6,
|
| 352 |
+
showfliers=False
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Colours for box plot
|
| 356 |
+
bp['boxes'][0].set_facecolor('lightblue')
|
| 357 |
+
bp['medians'][0].set_color('red')
|
| 358 |
+
bp['whiskers'][0].set_color('black')
|
| 359 |
+
bp['whiskers'][1].set_color('black')
|
| 360 |
+
bp['caps'][0].set_color('black')
|
| 361 |
+
bp['caps'][1].set_color('black')
|
| 362 |
+
|
| 363 |
+
# 2. Outliers
|
| 364 |
+
if outliers_info[col]['outliers_count'] > 0:
|
| 365 |
+
outlier_indices = outliers_info[col]['outlier_indices']
|
| 366 |
+
outlier_values = outliers_info[col]['outlier_values']
|
| 367 |
+
|
| 368 |
+
# Convert indices to positions for box plot
|
| 369 |
+
jitter = np.random.normal(0, 0.05, len(outlier_values))
|
| 370 |
+
|
| 371 |
+
ax.scatter(
|
| 372 |
+
np.ones(len(outlier_values)) + jitter,
|
| 373 |
+
outlier_values,
|
| 374 |
+
color='red',
|
| 375 |
+
alpha=0.6,
|
| 376 |
+
s=30,
|
| 377 |
+
edgecolors='black',
|
| 378 |
+
label=f'Outliers ({outliers_info[col]["outliers_count"]})'
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# 3. Histogram on same plot
|
| 382 |
+
ax2 = ax.twinx()
|
| 383 |
+
ax2.hist(
|
| 384 |
+
series.values,
|
| 385 |
+
bins=30,
|
| 386 |
+
alpha=0.3,
|
| 387 |
+
color='green',
|
| 388 |
+
density=True
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# 4. Normal distribution for comparison
|
| 392 |
+
if len(series) > 10:
|
| 393 |
+
xmin, xmax = ax.get_xlim()
|
| 394 |
+
x = np.linspace(series.min(), series.max(), 100)
|
| 395 |
+
mean = series.mean()
|
| 396 |
+
std = series.std()
|
| 397 |
+
|
| 398 |
+
if std > 0:
|
| 399 |
+
p = stats.norm.pdf(x, mean, std)
|
| 400 |
+
ax2.plot(x, p, 'k--', linewidth=1, label='Normal distribution')
|
| 401 |
+
|
| 402 |
+
ax.set_title(f'{col}\nOutliers: {outliers_info[col]["outliers_count"]} ({outliers_info[col]["outliers_percent"]:.1f}%)')
|
| 403 |
+
ax.set_ylabel('Value')
|
| 404 |
+
ax.grid(True, alpha=0.3)
|
| 405 |
+
|
| 406 |
+
# Legend
|
| 407 |
+
if outliers_info[col]['outliers_count'] > 0:
|
| 408 |
+
ax.legend(loc='upper right', fontsize=8)
|
| 409 |
+
ax2.legend(loc='upper left', fontsize=8)
|
| 410 |
+
|
| 411 |
+
plt.tight_layout()
|
| 412 |
+
plt.savefig(
|
| 413 |
+
f'{self.config.results_dir}/plots/outliers_analysis.png',
|
| 414 |
+
dpi=300,
|
| 415 |
+
bbox_inches='tight'
|
| 416 |
+
)
|
| 417 |
+
plt.show()
|
| 418 |
+
|
| 419 |
+
# Additional plots for time series
|
| 420 |
+
if isinstance(data.index, pd.DatetimeIndex) and len(columns) > 0:
|
| 421 |
+
self._plot_temporal_outliers(data, columns, outliers_info)
|
| 422 |
+
|
| 423 |
+
def _plot_temporal_outliers(
|
| 424 |
+
self,
|
| 425 |
+
data: pd.DataFrame,
|
| 426 |
+
columns: List[str],
|
| 427 |
+
outliers_info: Dict
|
| 428 |
+
) -> None:
|
| 429 |
+
"""Visualise outliers over time"""
|
| 430 |
+
n_plots = min(len(columns), 3)
|
| 431 |
+
|
| 432 |
+
fig, axes = plt.subplots(n_plots, 1, figsize=(14, 4 * n_plots))
|
| 433 |
+
if n_plots == 1:
|
| 434 |
+
axes = [axes]
|
| 435 |
+
|
| 436 |
+
for idx, (col, ax) in enumerate(zip(columns[:n_plots], axes)):
|
| 437 |
+
if col not in outliers_info:
|
| 438 |
+
continue
|
| 439 |
+
|
| 440 |
+
# Time series
|
| 441 |
+
ax.plot(data.index, data[col], alpha=0.7, linewidth=1, label='Original series')
|
| 442 |
+
|
| 443 |
+
# Outliers
|
| 444 |
+
if outliers_info[col]['outliers_count'] > 0:
|
| 445 |
+
outlier_indices = outliers_info[col]['outlier_indices']
|
| 446 |
+
outlier_values = outliers_info[col]['outlier_values']
|
| 447 |
+
|
| 448 |
+
ax.scatter(
|
| 449 |
+
outlier_indices,
|
| 450 |
+
outlier_values,
|
| 451 |
+
color='red',
|
| 452 |
+
s=40,
|
| 453 |
+
edgecolors='black',
|
| 454 |
+
zorder=5,
|
| 455 |
+
label='Outliers'
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# Moving average
|
| 459 |
+
if len(data) > 30:
|
| 460 |
+
rolling_mean = data[col].rolling(window=30, center=True).mean()
|
| 461 |
+
ax.plot(data.index, rolling_mean, 'orange', linewidth=2, label='Moving average (30)')
|
| 462 |
+
|
| 463 |
+
ax.set_title(f'Outliers over time: {col}')
|
| 464 |
+
ax.set_xlabel('Date')
|
| 465 |
+
ax.set_ylabel(col)
|
| 466 |
+
ax.legend(fontsize=8)
|
| 467 |
+
ax.grid(True, alpha=0.3)
|
| 468 |
+
|
| 469 |
+
plt.tight_layout()
|
| 470 |
+
plt.savefig(
|
| 471 |
+
f'{self.config.results_dir}/plots/temporal_outliers.png',
|
| 472 |
+
dpi=300,
|
| 473 |
+
bbox_inches='tight'
|
| 474 |
+
)
|
| 475 |
+
plt.show()
|
| 476 |
+
|
| 477 |
+
def handle(
|
| 478 |
+
self,
|
| 479 |
+
data: pd.DataFrame,
|
| 480 |
+
method: str = 'clip',
|
| 481 |
+
strategy: str = 'columnwise',
|
| 482 |
+
**kwargs
|
| 483 |
+
) -> pd.DataFrame:
|
| 484 |
+
"""
|
| 485 |
+
Handle outliers
|
| 486 |
+
|
| 487 |
+
Parameters:
|
| 488 |
+
-----------
|
| 489 |
+
data : pd.DataFrame
|
| 490 |
+
Input data
|
| 491 |
+
method : str
|
| 492 |
+
Handling method: 'clip', 'remove', 'mean', 'median', 'winsorize', 'transform', 'impute'
|
| 493 |
+
strategy : str
|
| 494 |
+
Strategy: 'columnwise', 'global', 'adaptive'
|
| 495 |
+
**kwargs : dict
|
| 496 |
+
Additional parameters for method
|
| 497 |
+
|
| 498 |
+
Returns:
|
| 499 |
+
--------
|
| 500 |
+
pd.DataFrame
|
| 501 |
+
Data with handled outliers
|
| 502 |
+
"""
|
| 503 |
+
logger.info("\n" + "="*80)
|
| 504 |
+
logger.info("HANDLING OUTLIERS")
|
| 505 |
+
logger.info("="*80)
|
| 506 |
+
|
| 507 |
+
if not self.outlier_info:
|
| 508 |
+
logger.warning("⚠ Perform outlier analysis first")
|
| 509 |
+
return data
|
| 510 |
+
|
| 511 |
+
data_processed = data.copy()
|
| 512 |
+
methods_applied = {}
|
| 513 |
+
|
| 514 |
+
for col, info in self.outlier_info.items():
|
| 515 |
+
if col not in data_processed.columns:
|
| 516 |
+
continue
|
| 517 |
+
|
| 518 |
+
outliers_count = info['outliers_count']
|
| 519 |
+
|
| 520 |
+
if outliers_count > 0:
|
| 521 |
+
# Create outlier mask
|
| 522 |
+
outlier_mask = pd.Series(False, index=data_processed.index)
|
| 523 |
+
if info['outlier_indices']:
|
| 524 |
+
outlier_indices = [idx for idx in info['outlier_indices'] if idx in data_processed.index]
|
| 525 |
+
outlier_mask.loc[outlier_indices] = True
|
| 526 |
+
|
| 527 |
+
# Determine boundaries
|
| 528 |
+
stats = info['statistics']
|
| 529 |
+
q1, q3 = stats['q1'], stats['q3']
|
| 530 |
+
iqr = q3 - q1
|
| 531 |
+
lower_bound = q1 - self.config.outlier_alpha * iqr
|
| 532 |
+
upper_bound = q3 + self.config.outlier_alpha * iqr
|
| 533 |
+
|
| 534 |
+
if method == 'clip':
|
| 535 |
+
# Clip values to boundaries
|
| 536 |
+
data_processed[col] = data_processed[col].clip(
|
| 537 |
+
lower=lower_bound,
|
| 538 |
+
upper=upper_bound
|
| 539 |
+
)
|
| 540 |
+
method_used = 'clipping'
|
| 541 |
+
affected = outliers_count
|
| 542 |
+
|
| 543 |
+
elif method == 'remove':
|
| 544 |
+
# Remove rows with outliers
|
| 545 |
+
data_processed = data_processed[~outlier_mask]
|
| 546 |
+
method_used = 'removal'
|
| 547 |
+
affected = outliers_count
|
| 548 |
+
|
| 549 |
+
elif method == 'mean':
|
| 550 |
+
# Replace outliers with mean value
|
| 551 |
+
mean_val = data_processed[col].mean()
|
| 552 |
+
data_processed.loc[outlier_mask, col] = mean_val
|
| 553 |
+
method_used = 'mean imputation'
|
| 554 |
+
affected = outliers_count
|
| 555 |
+
|
| 556 |
+
elif method == 'median':
|
| 557 |
+
# Replace outliers with median
|
| 558 |
+
median_val = data_processed[col].median()
|
| 559 |
+
data_processed.loc[outlier_mask, col] = median_val
|
| 560 |
+
method_used = 'median imputation'
|
| 561 |
+
affected = outliers_count
|
| 562 |
+
|
| 563 |
+
elif method == 'winsorize':
|
| 564 |
+
# Winsorisation
|
| 565 |
+
data_processed[col] = self._winsorize_series(
|
| 566 |
+
data_processed[col],
|
| 567 |
+
limits=kwargs.get('limits', (0.05, 0.05))
|
| 568 |
+
)
|
| 569 |
+
method_used = 'winsorization'
|
| 570 |
+
affected = outliers_count
|
| 571 |
+
|
| 572 |
+
elif method == 'transform':
|
| 573 |
+
# Transformation to reduce outlier impact
|
| 574 |
+
transform_method = kwargs.get('transform_method', 'log')
|
| 575 |
+
data_processed[col] = self._transform_series(
|
| 576 |
+
data_processed[col],
|
| 577 |
+
method=transform_method
|
| 578 |
+
)
|
| 579 |
+
method_used = f'{transform_method} transformation'
|
| 580 |
+
affected = 'all' # Transformation applied to all values
|
| 581 |
+
|
| 582 |
+
elif method == 'impute':
|
| 583 |
+
# Smart outlier imputation
|
| 584 |
+
impute_method = kwargs.get('impute_method', 'neighbors')
|
| 585 |
+
data_processed[col] = self._impute_outliers(
|
| 586 |
+
data_processed[col],
|
| 587 |
+
outlier_mask,
|
| 588 |
+
method=impute_method,
|
| 589 |
+
**kwargs
|
| 590 |
+
)
|
| 591 |
+
method_used = f'{impute_method} imputation'
|
| 592 |
+
affected = outliers_count
|
| 593 |
+
|
| 594 |
+
elif method == 'adaptive':
|
| 595 |
+
# Adaptive handling
|
| 596 |
+
data_processed[col] = self._adaptive_outlier_handling(
|
| 597 |
+
data_processed[col],
|
| 598 |
+
outlier_mask,
|
| 599 |
+
**kwargs
|
| 600 |
+
)
|
| 601 |
+
method_used = 'adaptive handling'
|
| 602 |
+
affected = outliers_count
|
| 603 |
+
|
| 604 |
+
else:
|
| 605 |
+
raise ValueError(f"Unknown method: {method}")
|
| 606 |
+
|
| 607 |
+
methods_applied[col] = {
|
| 608 |
+
'method': method_used,
|
| 609 |
+
'outliers_before': outliers_count,
|
| 610 |
+
'affected': affected,
|
| 611 |
+
'bounds': {
|
| 612 |
+
'lower': float(lower_bound),
|
| 613 |
+
'upper': float(upper_bound)
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
logger.info(f" {col}: {outliers_count} outliers handled ({method_used})")
|
| 618 |
+
|
| 619 |
+
self.handling_methods = methods_applied
|
| 620 |
+
|
| 621 |
+
# Handling statistics
|
| 622 |
+
total_outliers = sum(info['outliers_count'] for info in self.outlier_info.values())
|
| 623 |
+
total_affected = sum(method['affected'] for method in methods_applied.values()
|
| 624 |
+
if isinstance(method['affected'], (int, np.integer)))
|
| 625 |
+
|
| 626 |
+
logger.info(f"\n✓ {total_affected} out of {total_outliers} outliers handled")
|
| 627 |
+
logger.info(f" Data size before: {len(data)} rows")
|
| 628 |
+
logger.info(f" Data size after: {len(data_processed)} rows")
|
| 629 |
+
|
| 630 |
+
# Visualise results
|
| 631 |
+
if self.config.save_plots and methods_applied:
|
| 632 |
+
self._plot_outlier_handling_results(data, data_processed, methods_applied)
|
| 633 |
+
|
| 634 |
+
return data_processed
|
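For reference, a minimal sketch (not taken from this module) of the arithmetic behind the `'clip'` branch above, assuming the default IQR multiplier of 1.5; the sample values are illustrative only:

```python
import pandas as pd

# Toy series with one obvious high outlier
series = pd.Series([10, 12, 11, 13, 12, 95])

q1, q3 = series.quantile(0.25), series.quantile(0.75)
iqr = q3 - q1                                      # interquartile range
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
clipped = series.clip(lower=lower, upper=upper)    # 95 is pulled down to `upper`
```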
| 635 |
+
|
| 636 |
+
def _winsorize_series(
|
| 637 |
+
self,
|
| 638 |
+
series: pd.Series,
|
| 639 |
+
limits: Tuple[float, float] = (0.05, 0.05)
|
| 640 |
+
) -> pd.Series:
|
| 641 |
+
"""Winsorize series"""
|
| 642 |
+
from scipy.stats.mstats import winsorize
|
| 643 |
+
try:
|
| 644 |
+
winsorized = winsorize(series.values, limits=limits)
|
| 645 |
+
return pd.Series(winsorized, index=series.index)
|
| 646 |
+
except Exception:
|
| 647 |
+
return series
|
| 648 |
+
|
| 649 |
+
def _transform_series(
|
| 650 |
+
self,
|
| 651 |
+
series: pd.Series,
|
| 652 |
+
method: str = 'log'
|
| 653 |
+
) -> pd.Series:
|
| 654 |
+
"""Transform series to reduce outlier impact"""
|
| 655 |
+
series_transformed = series.copy()
|
| 656 |
+
|
| 657 |
+
if method == 'log':
|
| 658 |
+
# Logarithmic transformation
|
| 659 |
+
min_val = series.min()
|
| 660 |
+
if min_val <= 0:
|
| 661 |
+
shift = abs(min_val) + 1
|
| 662 |
+
series_transformed = np.log(series + shift)
|
| 663 |
+
else:
|
| 664 |
+
series_transformed = np.log(series)
|
| 665 |
+
|
| 666 |
+
elif method == 'boxcox':
|
| 667 |
+
# Box-Cox transformation
|
| 668 |
+
try:
|
| 669 |
+
from scipy.stats import boxcox
|
| 670 |
+
transformed, _ = boxcox(series - series.min() + 1)
|
| 671 |
+
series_transformed = pd.Series(transformed, index=series.index)
|
| 672 |
+
except Exception:
|
| 673 |
+
logger.warning("Box-Cox transformation failed, using log")
|
| 674 |
+
return self._transform_series(series, 'log')
|
| 675 |
+
|
| 676 |
+
elif method == 'sqrt':
|
| 677 |
+
# Square root
|
| 678 |
+
min_val = series.min()
|
| 679 |
+
if min_val < 0:
|
| 680 |
+
series_transformed = np.sqrt(series - min_val)
|
| 681 |
+
else:
|
| 682 |
+
series_transformed = np.sqrt(series)
|
| 683 |
+
|
| 684 |
+
elif method == 'yeojohnson':
|
| 685 |
+
# Yeo-Johnson transformation
|
| 686 |
+
try:
|
| 687 |
+
from scipy.stats import yeojohnson
|
| 688 |
+
transformed, _ = yeojohnson(series)
|
| 689 |
+
series_transformed = pd.Series(transformed, index=series.index)
|
| 690 |
+
except Exception:
|
| 691 |
+
logger.warning("Yeo-Johnson transformation failed, using log")
|
| 692 |
+
return self._transform_series(series, 'log')
|
| 693 |
+
|
| 694 |
+
return series_transformed
|
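A small illustration (assumed values, not from the repository) of the shift the `'log'` branch applies when a series contains non-positive values:

```python
import numpy as np
import pandas as pd

s = pd.Series([-3.0, 0.0, 2.0, 50.0])
shift = abs(s.min()) + 1      # 4.0 — makes every value strictly positive
s_log = np.log(s + shift)     # log(1), log(4), log(6), log(54)
```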
| 695 |
+
|
| 696 |
+
def _impute_outliers(
|
| 697 |
+
self,
|
| 698 |
+
series: pd.Series,
|
| 699 |
+
outlier_mask: pd.Series,
|
| 700 |
+
method: str = 'neighbors',
|
| 701 |
+
**kwargs
|
| 702 |
+
) -> pd.Series:
|
| 703 |
+
"""Smart outlier imputation"""
|
| 704 |
+
series_imputed = series.copy()
|
| 705 |
+
|
| 706 |
+
if method == 'neighbors':
|
| 707 |
+
# Replace with mean of neighbouring values
|
| 708 |
+
for idx in series[outlier_mask].index:
|
| 709 |
+
if idx in series.index:
|
| 710 |
+
pos = series.index.get_loc(idx)
|
| 711 |
+
neighbours = []
|
| 712 |
+
|
| 713 |
+
# Find nearest non-outliers
|
| 714 |
+
for offset in range(1, 6):
|
| 715 |
+
if pos - offset >= 0 and not outlier_mask.iloc[pos - offset]:
|
| 716 |
+
neighbours.append(series.iloc[pos - offset])
|
| 717 |
+
break
|
| 718 |
+
|
| 719 |
+
for offset in range(1, 6):
|
| 720 |
+
if pos + offset < len(series) and not outlier_mask.iloc[pos + offset]:
|
| 721 |
+
neighbours.append(series.iloc[pos + offset])
|
| 722 |
+
break
|
| 723 |
+
|
| 724 |
+
if neighbours:
|
| 725 |
+
series_imputed.loc[idx] = np.mean(neighbours)
|
| 726 |
+
|
| 727 |
+
elif method == 'interpolate':
|
| 728 |
+
# Interpolation
|
| 729 |
+
series_imputed = series.mask(outlier_mask).interpolate()
|
| 730 |
+
|
| 731 |
+
elif method == 'rolling':
|
| 732 |
+
# Replace with moving average
|
| 733 |
+
window = kwargs.get('window', 5)
|
| 734 |
+
rolling_mean = series.rolling(window=window, center=True, min_periods=1).mean()
|
| 735 |
+
series_imputed = series.mask(outlier_mask, rolling_mean)
|
| 736 |
+
|
| 737 |
+
return series_imputed
|
| 738 |
+
|
| 739 |
+
def _adaptive_outlier_handling(
|
| 740 |
+
self,
|
| 741 |
+
series: pd.Series,
|
| 742 |
+
outlier_mask: pd.Series,
|
| 743 |
+
**kwargs
|
| 744 |
+
) -> pd.Series:
|
| 745 |
+
"""Adaptive outlier handling"""
|
| 746 |
+
series_processed = series.copy()
|
| 747 |
+
outlier_indices = series[outlier_mask].index
|
| 748 |
+
|
| 749 |
+
for idx in outlier_indices:
|
| 750 |
+
if idx in series.index:
|
| 751 |
+
value = series.loc[idx]
|
| 752 |
+
stats = self.outlier_info.get(series.name, {}).get('statistics', {})
|
| 753 |
+
|
| 754 |
+
# Determine outlier type
|
| 755 |
+
q1 = stats.get('q1', series.quantile(0.25))
|
| 756 |
+
q3 = stats.get('q3', series.quantile(0.75))
|
| 757 |
+
iqr = q3 - q1
|
| 758 |
+
|
| 759 |
+
if value < q1 - 3 * iqr:
|
| 760 |
+
# Extreme low outlier
|
| 761 |
+
series_processed.loc[idx] = q1 - 1.5 * iqr
|
| 762 |
+
elif value > q3 + 3 * iqr:
|
| 763 |
+
# Extreme high outlier
|
| 764 |
+
series_processed.loc[idx] = q3 + 1.5 * iqr
|
| 765 |
+
else:
|
| 766 |
+
# Moderate outlier
|
| 767 |
+
pos = series.index.get_loc(idx)
|
| 768 |
+
# Use linear interpolation
|
| 769 |
+
if pos > 0 and pos < len(series) - 1:
|
| 770 |
+
series_processed.loc[idx] = (series.iloc[pos-1] + series.iloc[pos+1]) / 2
|
| 771 |
+
|
| 772 |
+
return series_processed
|
| 773 |
+
|
| 774 |
+
def _plot_outlier_handling_results(
|
| 775 |
+
self,
|
| 776 |
+
original_data: pd.DataFrame,
|
| 777 |
+
processed_data: pd.DataFrame,
|
| 778 |
+
methods_applied: Dict
|
| 779 |
+
) -> None:
|
| 780 |
+
"""Visualise outlier handling results"""
|
| 781 |
+
cols_to_plot = list(methods_applied.keys())[:3]
|
| 782 |
+
|
| 783 |
+
if not cols_to_plot:
|
| 784 |
+
return
|
| 785 |
+
|
| 786 |
+
fig, axes = plt.subplots(len(cols_to_plot), 2, figsize=(14, 4 * len(cols_to_plot)))
|
| 787 |
+
if len(cols_to_plot) == 1:
|
| 788 |
+
axes = axes.reshape(1, -1)
|
| 789 |
+
|
| 790 |
+
for idx, col in enumerate(cols_to_plot):
|
| 791 |
+
if col not in original_data.columns or col not in processed_data.columns:
|
| 792 |
+
continue
|
| 793 |
+
|
| 794 |
+
# Distribution before handling
|
| 795 |
+
axes[idx, 0].hist(original_data[col].dropna(), bins=30, alpha=0.5, label='Before', density=True)
|
| 796 |
+
axes[idx, 0].hist(processed_data[col].dropna(), bins=30, alpha=0.5, label='After', density=True)
|
| 797 |
+
axes[idx, 0].set_title(f'{col}: Distribution before/after')
|
| 798 |
+
axes[idx, 0].set_xlabel('Value')
|
| 799 |
+
axes[idx, 0].set_ylabel('Density')
|
| 800 |
+
axes[idx, 0].legend()
|
| 801 |
+
axes[idx, 0].grid(True, alpha=0.3)
|
| 802 |
+
|
| 803 |
+
# QQ plot for normality check
|
| 804 |
+
stats.probplot(original_data[col].dropna(), dist="norm", plot=axes[idx, 1])
|
| 805 |
+
axes[idx, 1].set_title(f'{col}: Q-Q plot (before handling)')
|
| 806 |
+
axes[idx, 1].grid(True, alpha=0.3)
|
| 807 |
+
|
| 808 |
+
plt.tight_layout()
|
| 809 |
+
plt.savefig(
|
| 810 |
+
f'{self.config.results_dir}/plots/outlier_handling_results.png',
|
| 811 |
+
dpi=300,
|
| 812 |
+
bbox_inches='tight'
|
| 813 |
+
)
|
| 814 |
+
plt.show()
|
| 815 |
+
|
| 816 |
+
def create_validation_rules(self) -> Dict:
|
| 817 |
+
"""Create validation rules based on outlier analysis"""
|
| 818 |
+
rules = {}
|
| 819 |
+
|
| 820 |
+
for col, info in self.outlier_info.items():
|
| 821 |
+
outliers_percent = info['outliers_percent']
|
| 822 |
+
skewness = info['statistics']['skewness']
|
| 823 |
+
|
| 824 |
+
rule = {
|
| 825 |
+
'outliers_percent': outliers_percent,
|
| 826 |
+
'skewness': skewness,
|
| 827 |
+
'recommended_action': 'none'
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
if outliers_percent > 10:
|
| 831 |
+
rule['recommended_action'] = 'aggressive_handling'
|
| 832 |
+
rule['reason'] = f'High outliers: {outliers_percent:.1f}%'
|
| 833 |
+
elif outliers_percent > 5:
|
| 834 |
+
rule['recommended_action'] = 'moderate_handling'
|
| 835 |
+
rule['reason'] = f'Moderate outliers: {outliers_percent:.1f}%'
|
| 836 |
+
elif outliers_percent > 1:
|
| 837 |
+
rule['recommended_action'] = 'conservative_handling'
|
| 838 |
+
rule['reason'] = f'Low outliers: {outliers_percent:.1f}%'
|
| 839 |
+
|
| 840 |
+
if abs(skewness) > 1:
|
| 841 |
+
rule['skewness_issue'] = True
|
| 842 |
+
rule['skewness_reason'] = f'Strong skewness: {skewness:.2f}'
|
| 843 |
+
if rule['recommended_action'] == 'none':
|
| 844 |
+
rule['recommended_action'] = 'transformation'
|
| 845 |
+
|
| 846 |
+
rules[col] = rule
|
| 847 |
+
|
| 848 |
+
return rules
|
| 849 |
+
|
| 850 |
+
def get_report(self) -> Dict:
|
| 851 |
+
"""Get outlier analysis report"""
|
| 852 |
+
return {
|
| 853 |
+
'outlier_info': self.outlier_info,
|
| 854 |
+
'handling_methods': self.handling_methods,
|
| 855 |
+
'detection_methods': self.detection_methods,
|
| 856 |
+
'validation_rules': self.create_validation_rules()
|
| 857 |
+
}
|
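A hypothetical standalone use of the analyser added above, outside the main pipeline; the configuration values and file name are placeholders, not taken from this commit:

```python
import pandas as pd
from config.config import Config
from outliers.outlier_analyzer import OutlierAnalyser

config = Config(results_dir='results', outlier_method='iqr')
analyser = OutlierAnalyser(config)

df = pd.read_csv('temp_data.csv', parse_dates=['date'], index_col='date')
analyser.analyse(df, method='iqr',
                 columns=df.select_dtypes(include='number').columns.tolist())
df_clean = analyser.handle(df, method='clip', strategy='columnwise')
report = analyser.get_report()
```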
pipeline/__init__.py
ADDED
|
File without changes
|
pipeline/main_pipeline.py
ADDED
|
@@ -0,0 +1,603 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 14: MAIN PIPELINE
|
| 3 |
+
# ============================================
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import traceback
|
| 8 |
+
from typing import Any, Dict, Optional
|
| 9 |
+
import logging
logger = logging.getLogger(__name__)
|
| 10 |
+
from config.config import Config
|
| 11 |
+
from correlations.correlation_analyzer import CorrelationAnalyzer
|
| 12 |
+
from data_loader.data_loader import DataLoader
|
| 13 |
+
from decomposition.decomposer import TimeSeriesDecomposer
|
| 14 |
+
from feature_selection.feature_selector import FeatureSelector
|
| 15 |
+
from features.feature_engineer import FeatureEngineer
|
| 16 |
+
|
| 17 |
+
from missing_values.missing_analyzer import MissingValueAnalyser
|
| 18 |
+
from outliers.outlier_analyzer import OutlierAnalyser
|
| 19 |
+
from scaling.data_scaler import DataScaler
|
| 20 |
+
from splitting.data_splitter import DataSplitter
|
| 21 |
+
from stationarity.stationarity_checker import StationarityChecker
|
| 22 |
+
from validation.data_validator import DataValidator
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from visualization.visualization_manager import VisualisationManager
|
| 27 |
+
|
| 28 |
+
class EnhancedDataPreprocessingPipeline:
|
| 29 |
+
"""Enhanced main data preprocessing pipeline"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, config: Config):
|
| 32 |
+
"""
|
| 33 |
+
Initialise pipeline
|
| 34 |
+
|
| 35 |
+
Parameters:
|
| 36 |
+
-----------
|
| 37 |
+
config : Config
|
| 38 |
+
Experiment configuration
|
| 39 |
+
"""
|
| 40 |
+
self.config = config
|
| 41 |
+
self.data_loader = DataLoader(config)
|
| 42 |
+
self.missing_analyser = MissingValueAnalyser(config)
|
| 43 |
+
self.outlier_analyser = OutlierAnalyser(config)
|
| 44 |
+
self.feature_engineer = FeatureEngineer(config)
|
| 45 |
+
self.stationarity_checker = StationarityChecker(config)
|
| 46 |
+
self.decomposer = TimeSeriesDecomposer(config)
|
| 47 |
+
self.correlation_analyser = CorrelationAnalyzer(config)
|
| 48 |
+
self.data_splitter = DataSplitter(config)
|
| 49 |
+
self.data_scaler = DataScaler(config)
|
| 50 |
+
self.feature_selector = FeatureSelector(config)
|
| 51 |
+
self.data_validator = DataValidator(config)
|
| 52 |
+
self.visualisation_manager = VisualisationManager(config)
|
| 53 |
+
|
| 54 |
+
self.results = {}
|
| 55 |
+
self.processed_data = None
|
| 56 |
+
self.train_data = None
|
| 57 |
+
self.val_data = None
|
| 58 |
+
self.test_data = None
|
| 59 |
+
self.is_fitted = False
|
| 60 |
+
|
| 61 |
+
def run_full_pipeline(
|
| 62 |
+
self,
|
| 63 |
+
data_path: Optional[str] = None,
|
| 64 |
+
use_synthetic: bool = False,
|
| 65 |
+
save_intermediate: bool = True,
|
| 66 |
+
create_reports: bool = True
|
| 67 |
+
) -> pd.DataFrame:
|
| 68 |
+
"""
|
| 69 |
+
Run enhanced full data preprocessing pipeline
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
-----------
|
| 73 |
+
data_path : str, optional
|
| 74 |
+
Path to data. If None, uses configuration value.
|
| 75 |
+
use_synthetic : bool
|
| 76 |
+
Use synthetic data for testing
|
| 77 |
+
save_intermediate : bool
|
| 78 |
+
Save intermediate results
|
| 79 |
+
create_reports : bool
|
| 80 |
+
Create reports
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
--------
|
| 84 |
+
pd.DataFrame
|
| 85 |
+
Processed data
|
| 86 |
+
"""
|
| 87 |
+
logger.info("\n" + "="*80)
|
| 88 |
+
logger.info("RUNNING ENHANCED DATA PREPROCESSING PIPELINE")
|
| 89 |
+
logger.info("="*80)
|
| 90 |
+
|
| 91 |
+
start_time = datetime.now()
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# Step 1: Data loading
|
| 95 |
+
logger.info("\n" + "="*80)
|
| 96 |
+
logger.info("STEP 1: DATA LOADING")
|
| 97 |
+
logger.info("="*80)
|
| 98 |
+
|
| 99 |
+
if use_synthetic:
|
| 100 |
+
data = self.data_loader.create_synthetic_data(
|
| 101 |
+
n_days=365*20,
|
| 102 |
+
trend_strength=0.01,
|
| 103 |
+
noise_std=10,
|
| 104 |
+
include_exogenous=True
|
| 105 |
+
)
|
| 106 |
+
else:
|
| 107 |
+
data = self.data_loader.load_from_csv(
|
| 108 |
+
data_path=data_path,
|
| 109 |
+
parse_dates=['date']
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Check for date index
|
| 113 |
+
if not isinstance(data.index, pd.DatetimeIndex):
|
| 114 |
+
logger.warning("Index is not DatetimeIndex, setting...")
|
| 115 |
+
if 'date' in data.columns:
|
| 116 |
+
data = data.set_index('date')
|
| 117 |
+
logger.info("Index set from 'date' column")
|
| 118 |
+
|
| 119 |
+
self.results['data_loading'] = {
|
| 120 |
+
'shape': list(data.shape),
|
| 121 |
+
'columns': list(data.columns),
|
| 122 |
+
'date_range': {
|
| 123 |
+
'min': data.index.min().strftime('%Y-%m-%d') if isinstance(data.index, pd.DatetimeIndex) else None,
|
| 124 |
+
'max': data.index.max().strftime('%Y-%m-%d') if isinstance(data.index, pd.DatetimeIndex) else None
|
| 125 |
+
},
|
| 126 |
+
'is_datetime_index': isinstance(data.index, pd.DatetimeIndex)
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
# Save raw data information
|
| 130 |
+
self.data_loader.save_raw_data_info()
|
| 131 |
+
|
| 132 |
+
# Step 2: Raw data validation
|
| 133 |
+
logger.info("\n" + "="*80)
|
| 134 |
+
logger.info("STEP 2: RAW DATA VALIDATION")
|
| 135 |
+
logger.info("="*80)
|
| 136 |
+
|
| 137 |
+
raw_validation = self.data_validator.validate(
|
| 138 |
+
data, stage='raw', detailed=True
|
| 139 |
+
)
|
| 140 |
+
self.results['raw_validation'] = raw_validation
|
| 141 |
+
|
| 142 |
+
if raw_validation['status'] == 'FAIL':
|
| 143 |
+
logger.warning("⚠ Raw data has critical issues!")
|
| 144 |
+
if not self.config.enable_validation:
|
| 145 |
+
logger.warning("Validation disabled in configuration, continuing processing")
|
| 146 |
+
else:
|
| 147 |
+
logger.error("Pipeline interrupted due to data issues")
|
| 148 |
+
return None
|
| 149 |
+
|
| 150 |
+
# Step 3: Missing values analysis and handling
|
| 151 |
+
logger.info("\n" + "="*80)
|
| 152 |
+
logger.info("STEP 3: MISSING VALUES HANDLING")
|
| 153 |
+
logger.info("="*80)
|
| 154 |
+
|
| 155 |
+
missing_info = self.missing_analyser.analyse(data, detailed=True)
|
| 156 |
+
self.results['missing_analysis'] = missing_info
|
| 157 |
+
|
| 158 |
+
# Handle missing values
|
| 159 |
+
data = self.missing_analyser.handle(
|
| 160 |
+
data,
|
| 161 |
+
method='interpolate',
|
| 162 |
+
strategy='columnwise'
|
| 163 |
+
)
|
| 164 |
+
self.results['missing_handling'] = self.missing_analyser.handling_methods
|
| 165 |
+
|
| 166 |
+
# Step 4: Outlier analysis and handling
|
| 167 |
+
logger.info("\n" + "="*80)
|
| 168 |
+
logger.info("STEP 4: OUTLIER HANDLING")
|
| 169 |
+
logger.info("="*80)
|
| 170 |
+
|
| 171 |
+
outlier_info = self.outlier_analyser.analyse(
|
| 172 |
+
data,
|
| 173 |
+
method=self.config.outlier_method,
|
| 174 |
+
columns=data.select_dtypes(include=[np.number]).columns.tolist()
|
| 175 |
+
)
|
| 176 |
+
self.results['outlier_analysis'] = outlier_info
|
| 177 |
+
|
| 178 |
+
# Handle outliers
|
| 179 |
+
data = self.outlier_analyser.handle(
|
| 180 |
+
data,
|
| 181 |
+
method='clip',
|
| 182 |
+
strategy='columnwise'
|
| 183 |
+
)
|
| 184 |
+
self.results['outlier_handling'] = self.outlier_analyser.handling_methods
|
| 185 |
+
|
| 186 |
+
# Step 5: Feature engineering
|
| 187 |
+
logger.info("\n" + "="*80)
|
| 188 |
+
logger.info("STEP 5: FEATURE ENGINEERING")
|
| 189 |
+
logger.info("="*80)
|
| 190 |
+
|
| 191 |
+
data = self.feature_engineer.create_all_features(data)
|
| 192 |
+
self.results['feature_engineering'] = self.feature_engineer.feature_info
|
| 193 |
+
|
| 194 |
+
# Check for data after feature engineering
|
| 195 |
+
if len(data) == 0:
|
| 196 |
+
logger.error("No data remaining after feature engineering!")
|
| 197 |
+
return None
|
| 198 |
+
|
| 199 |
+
# Step 6: Stationarity analysis
|
| 200 |
+
logger.info("\n" + "="*80)
|
| 201 |
+
logger.info("STEP 6: STATIONARITY ANALYSIS")
|
| 202 |
+
logger.info("="*80)
|
| 203 |
+
|
| 204 |
+
stationarity_results = self.stationarity_checker.check(
|
| 205 |
+
data,
|
| 206 |
+
target_col=self.config.target_column,
|
| 207 |
+
make_stationary=True,
|
| 208 |
+
try_transformations=True
|
| 209 |
+
)
|
| 210 |
+
self.results['stationarity_analysis'] = stationarity_results
|
| 211 |
+
|
| 212 |
+
# Step 7: Time series decomposition
|
| 213 |
+
logger.info("\n" + "="*80)
|
| 214 |
+
logger.info("STEP 7: TIME SERIES DECOMPOSITION")
|
| 215 |
+
logger.info("="*80)
|
| 216 |
+
|
| 217 |
+
if isinstance(data.index, pd.DatetimeIndex) and len(data) > 365:
|
| 218 |
+
decomposition_results = self.decomposer.decompose(
|
| 219 |
+
data,
|
| 220 |
+
target_col=self.config.target_column,
|
| 221 |
+
method='stl',
|
| 222 |
+
period=self.config.seasonal_period
|
| 223 |
+
)
|
| 224 |
+
self.results['decomposition'] = decomposition_results
|
| 225 |
+
else:
|
| 226 |
+
logger.info("Skipped: insufficient data or no DatetimeIndex")
|
| 227 |
+
self.results['decomposition'] = {'skipped': 'insufficient data or no DatetimeIndex'}
|
| 228 |
+
|
| 229 |
+
# Step 8: Correlation analysis
|
| 230 |
+
logger.info("\n" + "="*80)
|
| 231 |
+
logger.info("STEP 8: CORRELATION ANALYSIS")
|
| 232 |
+
logger.info("="*80)
|
| 233 |
+
|
| 234 |
+
corr_matrix = self.correlation_analyser.analyze(
|
| 235 |
+
data,
|
| 236 |
+
target_col=self.config.target_column,
|
| 237 |
+
threshold=0.8,
|
| 238 |
+
detailed=True
|
| 239 |
+
)
|
| 240 |
+
self.results['correlation_analysis'] = self.correlation_analyser.get_report()
|
| 241 |
+
|
| 242 |
+
# Remove highly correlated features
|
| 243 |
+
if not corr_matrix.empty:
|
| 244 |
+
data = self.correlation_analyser.remove_highly_correlated(
|
| 245 |
+
data,
|
| 246 |
+
threshold=0.95,
|
| 247 |
+
method='variance',
|
| 248 |
+
keep_target=True
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
logger.warning("Correlation matrix empty, skipping feature removal")
|
| 252 |
+
|
| 253 |
+
# Step 9: Processed data validation
|
| 254 |
+
logger.info("\n" + "="*80)
|
| 255 |
+
logger.info("STEP 9: PROCESSED DATA VALIDATION")
|
| 256 |
+
logger.info("="*80)
|
| 257 |
+
|
| 258 |
+
processed_validation = self.data_validator.validate(
|
| 259 |
+
data, stage='processed', detailed=True
|
| 260 |
+
)
|
| 261 |
+
self.results['processed_validation'] = processed_validation
|
| 262 |
+
|
| 263 |
+
if processed_validation['status'] == 'FAIL':
|
| 264 |
+
logger.warning("⚠ Processed data failed validation!")
|
| 265 |
+
logger.warning("Continuing pipeline, but data quality may be low")
|
| 266 |
+
elif processed_validation['status'] == 'WARNING':
|
| 267 |
+
logger.warning("⚠ Processed data requires attention")
|
| 268 |
+
|
| 269 |
+
# Step 10: Data splitting
|
| 270 |
+
logger.info("\n" + "="*80)
|
| 271 |
+
logger.info("STEP 10: DATA SPLITTING")
|
| 272 |
+
logger.info("="*80)
|
| 273 |
+
|
| 274 |
+
train_data, val_data, test_data = self.data_splitter.split(
|
| 275 |
+
data,
|
| 276 |
+
method=self.config.split_method,
|
| 277 |
+
test_size=self.config.test_size,
|
| 278 |
+
validation_size=self.config.validation_size
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
self.train_data = train_data
|
| 282 |
+
self.val_data = val_data
|
| 283 |
+
self.test_data = test_data
|
| 284 |
+
|
| 285 |
+
self.results['data_splitting'] = self.data_splitter.split_info
|
| 286 |
+
|
| 287 |
+
# Step 11: Data scaling
|
| 288 |
+
logger.info("\n" + "="*80)
|
| 289 |
+
logger.info("STEP 11: DATA SCALING")
|
| 290 |
+
logger.info("="*80)
|
| 291 |
+
|
| 292 |
+
# Scale training data
|
| 293 |
+
train_data_scaled = self.data_scaler.fit_transform(
|
| 294 |
+
train_data,
|
| 295 |
+
method=self.config.scaling_method,
|
| 296 |
+
target_col=self.config.target_column,
|
| 297 |
+
fit_on_train=True
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# Apply same scaling to validation and test data
|
| 301 |
+
val_data_scaled = self.data_scaler.transform(val_data)
|
| 302 |
+
test_data_scaled = self.data_scaler.transform(test_data)
|
| 303 |
+
|
| 304 |
+
self.train_data = train_data_scaled
|
| 305 |
+
self.val_data = val_data_scaled
|
| 306 |
+
self.test_data = test_data_scaled
|
| 307 |
+
|
| 308 |
+
self.results['data_scaling'] = self.data_scaler.get_report()
|
| 309 |
+
|
| 310 |
+
# Step 12: Feature selection
|
| 311 |
+
logger.info("\n" + "="*80)
|
| 312 |
+
logger.info("STEP 12: FEATURE SELECTION")
|
| 313 |
+
logger.info("="*80)
|
| 314 |
+
|
| 315 |
+
if len(train_data_scaled.columns) > 5:
|
| 316 |
+
# Select features on training data
|
| 317 |
+
train_data_selected = self.feature_selector.select(
|
| 318 |
+
train_data_scaled,
|
| 319 |
+
method=self.config.feature_selection_method,
|
| 320 |
+
n_features=min(self.config.max_features, len(train_data_scaled.columns) - 1)
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Save selected features
|
| 324 |
+
selected_features = self.feature_selector.selected_features
|
| 325 |
+
|
| 326 |
+
# Apply same selection to validation and test data
|
| 327 |
+
features_to_keep = selected_features + [self.config.target_column]
|
| 328 |
+
features_to_keep = [f for f in features_to_keep if f in val_data_scaled.columns]
|
| 329 |
+
|
| 330 |
+
if len(features_to_keep) > 1:
|
| 331 |
+
self.train_data = train_data_scaled[features_to_keep].copy()
|
| 332 |
+
self.val_data = val_data_scaled[features_to_keep].copy()
|
| 333 |
+
self.test_data = test_data_scaled[features_to_keep].copy()
|
| 334 |
+
else:
|
| 335 |
+
logger.warning("Failed to select features, using all")
|
| 336 |
+
|
| 337 |
+
self.results['feature_selection'] = self.feature_selector.get_report()
|
| 338 |
+
else:
|
| 339 |
+
logger.info("Skipped: insufficient features for selection")
|
| 340 |
+
self.results['feature_selection'] = {'skipped': 'insufficient features'}
|
| 341 |
+
|
| 342 |
+
# Step 13: Final validation
|
| 343 |
+
logger.info("\n" + "="*80)
|
| 344 |
+
logger.info("STEP 13: FINAL VALIDATION")
|
| 345 |
+
logger.info("="*80)
|
| 346 |
+
|
| 347 |
+
# Combine all data for final validation
|
| 348 |
+
all_processed_data = pd.concat([self.train_data, self.val_data, self.test_data])
|
| 349 |
+
|
| 350 |
+
final_validation = self.data_validator.validate(
|
| 351 |
+
all_processed_data, stage='final', detailed=True
|
| 352 |
+
)
|
| 353 |
+
self.results['final_validation'] = final_validation
|
| 354 |
+
|
| 355 |
+
self.processed_data = all_processed_data
|
| 356 |
+
self.is_fitted = True
|
| 357 |
+
|
| 358 |
+
# Step 14: Additional multicollinearity cleaning
|
| 359 |
+
logger.info("\n" + "="*80)
|
| 360 |
+
logger.info("STEP 14: ADDITIONAL MULTICOLLINEARITY CLEANING")
|
| 361 |
+
logger.info("="*80)
|
| 362 |
+
|
| 363 |
+
# Remove temporal features with extreme VIF
|
| 364 |
+
self.processed_data = self._remove_extreme_vif_features(self.processed_data)
|
| 365 |
+
self.train_data = self.train_data[self.processed_data.columns]
|
| 366 |
+
self.val_data = self.val_data[self.processed_data.columns]
|
| 367 |
+
self.test_data = self.test_data[self.processed_data.columns]
|
| 368 |
+
|
| 369 |
+
# Step 15: Create visualisations and reports
|
| 370 |
+
logger.info("\n" + "="*80)
|
| 371 |
+
logger.info("STEP 15: CREATING REPORTS AND VISUALISATIONS")
|
| 372 |
+
logger.info("="*80)
|
| 373 |
+
|
| 374 |
+
if create_reports:
|
| 375 |
+
self.create_all_reports()
|
| 376 |
+
self.create_all_visualisations()
|
| 377 |
+
|
| 378 |
+
# Calculate execution time
|
| 379 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 380 |
+
|
| 381 |
+
# Save final results
|
| 382 |
+
self.results['pipeline_execution'] = {
|
| 383 |
+
'start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'),
|
| 384 |
+
'end_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
| 385 |
+
'execution_time_seconds': execution_time,
|
| 386 |
+
'success': True,
|
| 387 |
+
'stages_completed': 15
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
# Save configuration and results
|
| 391 |
+
self.save_pipeline_results()
|
| 392 |
+
|
| 393 |
+
logger.info("\n" + "="*80)
|
| 394 |
+
logger.info("ENHANCED PIPELINE SUCCESSFULLY COMPLETED!")
|
| 395 |
+
logger.info("="*80)
|
| 396 |
+
logger.info(f"Execution time: {execution_time:.2f} seconds")
|
| 397 |
+
logger.info(f"Initial data size: {self.results['data_loading']['shape']}")
|
| 398 |
+
logger.info(f"Final data size: {list(self.processed_data.shape)}")
|
| 399 |
+
logger.info(f"Data quality: {final_validation['overall_score']}/100")
|
| 400 |
+
logger.info(f"Status: {final_validation['status']}")
|
| 401 |
+
logger.info(f"Training data: {len(self.train_data)} records")
|
| 402 |
+
logger.info(f"Features in final set: {len(self.train_data.columns)}")
|
| 403 |
+
|
| 404 |
+
return self.processed_data
|
| 405 |
+
|
| 406 |
+
except Exception as e:
|
| 407 |
+
logger.error(f"✗ Pipeline error: {e}")
|
| 408 |
+
logger.error(traceback.format_exc())
|
| 409 |
+
|
| 410 |
+
self.results['pipeline_execution'] = {
|
| 411 |
+
'success': False,
|
| 412 |
+
'error': str(e),
|
| 413 |
+
'traceback': traceback.format_exc()
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
# Save partial results
|
| 417 |
+
self.save_pipeline_results()
|
| 418 |
+
|
| 419 |
+
return None
|
| 420 |
+
|
| 421 |
+
def _remove_extreme_vif_features(self, data: pd.DataFrame) -> pd.DataFrame:
|
| 422 |
+
"""Remove features with extreme VIF"""
|
| 423 |
+
data_clean = data.copy()
|
| 424 |
+
|
| 425 |
+
# Identify features with extreme VIF for removal
|
| 426 |
+
extreme_vif_features = [
|
| 427 |
+
'year', 'day', 'dayofyear', 'days_from_start',
|
| 428 |
+
'raskhodvoda_zscore' # Usually has extreme VIF
|
| 429 |
+
]
|
| 430 |
+
|
| 431 |
+
# Remove only those present in data
|
| 432 |
+
features_to_remove = [f for f in extreme_vif_features if f in data_clean.columns]
|
| 433 |
+
|
| 434 |
+
if features_to_remove:
|
| 435 |
+
logger.info(f"Removing features with extreme VIF: {features_to_remove}")
|
| 436 |
+
data_clean = data_clean.drop(columns=features_to_remove)
|
| 437 |
+
|
| 438 |
+
return data_clean
|
| 439 |
+
|
| 440 |
+
def create_all_reports(self) -> None:
|
| 441 |
+
"""Create all reports"""
|
| 442 |
+
logger.info("Creating reports...")
|
| 443 |
+
|
| 444 |
+
# 1. Save validation results
|
| 445 |
+
for stage in ['raw', 'processed', 'final']:
|
| 446 |
+
if stage in self.data_validator.validation_results:
|
| 447 |
+
self.data_validator.save_report(stage)
|
| 448 |
+
|
| 449 |
+
# 2. Save plots information
|
| 450 |
+
self.visualisation_manager.save_plots_info()
|
| 451 |
+
|
| 452 |
+
# 3. Create summary report
|
| 453 |
+
self.create_summary_report()
|
| 454 |
+
|
| 455 |
+
logger.info("✓ All reports created")
|
| 456 |
+
|
| 457 |
+
def create_all_visualisations(self) -> None:
|
| 458 |
+
"""Create all visualisations"""
|
| 459 |
+
logger.info("Creating visualisations...")
|
| 460 |
+
|
| 461 |
+
if self.processed_data is not None:
|
| 462 |
+
# 1. Summary dashboard
|
| 463 |
+
preprocessing_stages = {
|
| 464 |
+
'Loading': self.results['data_loading']['shape'][1] if 'data_loading' in self.results else 0,
|
| 465 |
+
'After cleaning': len(self.processed_data.columns),
|
| 466 |
+
'Features created': self.feature_engineer.feature_info.get('features_created', 0),
|
| 467 |
+
'Features selected': len(self.feature_selector.selected_features) if hasattr(self.feature_selector, 'selected_features') else 0
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
self.visualisation_manager.create_summary_dashboard(
|
| 471 |
+
self.processed_data,
|
| 472 |
+
preprocessing_stages
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
logger.info("✓ All visualisations created")
|
| 476 |
+
|
| 477 |
+
def create_summary_report(self) -> None:
|
| 478 |
+
"""Create summary report"""
|
| 479 |
+
report = {
|
| 480 |
+
'pipeline_summary': {
|
| 481 |
+
'config': self.config.to_dict(),
|
| 482 |
+
'execution': self.results.get('pipeline_execution', {}),
|
| 483 |
+
'data_statistics': {
|
| 484 |
+
'initial_shape': self.results.get('data_loading', {}).get('shape', []),
|
| 485 |
+
'final_shape': list(self.processed_data.shape) if self.processed_data is not None else [],
|
| 486 |
+
'target_column': self.config.target_column,
|
| 487 |
+
'features_created': self.feature_engineer.feature_info.get('features_created', 0),
|
| 488 |
+
'features_selected': len(self.feature_selector.selected_features) if hasattr(self.feature_selector, 'selected_features') else 0
|
| 489 |
+
}
|
| 490 |
+
},
|
| 491 |
+
'validation_summary': {},
|
| 492 |
+
'quality_metrics': {}
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
# Add validation results
|
| 496 |
+
for stage in ['raw', 'processed', 'final']:
|
| 497 |
+
if stage in self.data_validator.validation_results:
|
| 498 |
+
stage_results = self.data_validator.validation_results[stage]
|
| 499 |
+
report['validation_summary'][stage] = {
|
| 500 |
+
'status': stage_results.get('status'),
|
| 501 |
+
'score': stage_results.get('overall_score'),
|
| 502 |
+
'issues_count': sum(len(issues) for issues in stage_results.get('issues', {}).values()),
|
| 503 |
+
'checks_passed': sum(1 for check in stage_results.get('basic_checks', {}).values()
|
| 504 |
+
if check.get('passed', False))
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
# Save report
|
| 508 |
+
report_path = f'{self.config.results_dir}/reports/pipeline_summary.json'
|
| 509 |
+
|
| 510 |
+
with open(report_path, 'w', encoding='utf-8') as f:
|
| 511 |
+
json.dump(report, f, indent=4, ensure_ascii=False)
|
| 512 |
+
|
| 513 |
+
logger.info(f"✓ Summary report saved: {report_path}")
|
| 514 |
+
|
| 515 |
+
def save_pipeline_results(self) -> None:
|
| 516 |
+
"""Save all pipeline results"""
|
| 517 |
+
# Save configuration
|
| 518 |
+
self.config.save()
|
| 519 |
+
|
| 520 |
+
# Save data
|
| 521 |
+
if self.processed_data is not None:
|
| 522 |
+
# Save processed data
|
| 523 |
+
data_path = f'{self.config.results_dir}/processed_data/processed_data.csv'
|
| 524 |
+
self.processed_data.to_csv(data_path)
|
| 525 |
+
logger.info(f"✓ Processed data saved: {data_path}")
|
| 526 |
+
|
| 527 |
+
# Save split data
|
| 528 |
+
if self.train_data is not None:
|
| 529 |
+
self.train_data.to_csv(f'{self.config.results_dir}/processed_data/train_data.csv')
|
| 530 |
+
self.val_data.to_csv(f'{self.config.results_dir}/processed_data/val_data.csv')
|
| 531 |
+
self.test_data.to_csv(f'{self.config.results_dir}/processed_data/test_data.csv')
|
| 532 |
+
|
| 533 |
+
def get_final_data_for_modelling(self) -> Dict[str, Any]:
|
| 534 |
+
"""Prepare data for modelling"""
|
| 535 |
+
if not self.is_fitted:
|
| 536 |
+
logger.warning("Pipeline not executed, data not ready")
|
| 537 |
+
return {}
|
| 538 |
+
|
| 539 |
+
return {
|
| 540 |
+
'X_train': self.train_data.drop(columns=[self.config.target_column]),
|
| 541 |
+
'y_train': self.train_data[self.config.target_column],
|
| 542 |
+
'X_val': self.val_data.drop(columns=[self.config.target_column]),
|
| 543 |
+
'y_val': self.val_data[self.config.target_column],
|
| 544 |
+
'X_test': self.test_data.drop(columns=[self.config.target_column]),
|
| 545 |
+
'y_test': self.test_data[self.config.target_column],
|
| 546 |
+
'feature_names': self.train_data.drop(columns=[self.config.target_column]).columns.tolist(),
|
| 547 |
+
'scaler': self.data_scaler,
|
| 548 |
+
'feature_selector': self.feature_selector,
|
| 549 |
+
'results': self.results
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
# ============================================
|
| 554 |
+
# QUICK LAUNCH FUNCTION
|
| 555 |
+
# ============================================
|
| 556 |
+
def run_enhanced_preprocessing(
|
| 557 |
+
config_path: Optional[str] = None,
|
| 558 |
+
data_path: Optional[str] = None,
|
| 559 |
+
use_synthetic: bool = False,
|
| 560 |
+
save_results: bool = True
|
| 561 |
+
) -> EnhancedDataPreprocessingPipeline:
|
| 562 |
+
"""
|
| 563 |
+
Quick launch function for enhanced pipeline
|
| 564 |
+
|
| 565 |
+
Parameters:
|
| 566 |
+
-----------
|
| 567 |
+
config_path : str, optional
|
| 568 |
+
Path to configuration file
|
| 569 |
+
data_path : str, optional
|
| 570 |
+
Path to data
|
| 571 |
+
use_synthetic : bool
|
| 572 |
+
Use synthetic data
|
| 573 |
+
save_results : bool
|
| 574 |
+
Save results
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
--------
|
| 578 |
+
EnhancedDataPreprocessingPipeline
|
| 579 |
+
Pipeline object with results
|
| 580 |
+
"""
|
| 581 |
+
# Load or create configuration
|
| 582 |
+
if config_path and os.path.exists(config_path):
|
| 583 |
+
config = Config.load(config_path)
|
| 584 |
+
logger.info(f"Configuration loaded from {config_path}")
|
| 585 |
+
else:
|
| 586 |
+
config = Config()
|
| 587 |
+
logger.info("Using default configuration")
|
| 588 |
+
|
| 589 |
+
# Update data path if specified
|
| 590 |
+
if data_path:
|
| 591 |
+
config.data_path = data_path
|
| 592 |
+
|
| 593 |
+
# Create and run pipeline
|
| 594 |
+
pipeline = EnhancedDataPreprocessingPipeline(config)
|
| 595 |
+
|
| 596 |
+
pipeline.run_full_pipeline(
|
| 597 |
+
data_path=data_path,
|
| 598 |
+
use_synthetic=use_synthetic,
|
| 599 |
+
save_intermediate=save_results,
|
| 600 |
+
create_reports=save_results
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
return pipeline
|
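A hedged sketch of how the quick-launch helper and the modelling accessor above might be combined; the data path is a placeholder:

```python
from pipeline.main_pipeline import run_enhanced_preprocessing

pipeline = run_enhanced_preprocessing(
    config_path=None,           # fall back to the default Config()
    data_path='temp_data.csv',
    use_synthetic=False,
    save_results=True
)

if pipeline.is_fitted:
    model_data = pipeline.get_final_data_for_modelling()
    X_train, y_train = model_data['X_train'], model_data['y_train']
    X_val, y_val = model_data['X_val'], model_data['y_val']
```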
requirements.txt
CHANGED
|
@@ -1,3 +1,100 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
| 1 |
+
absl-py==2.3.1
|
| 2 |
+
altair==6.0.0
|
| 3 |
+
anyio==4.11.0
|
| 4 |
+
astunparse==1.6.3
|
| 5 |
+
attrs==25.4.0
|
| 6 |
+
blinker==1.9.0
|
| 7 |
+
cachetools==6.2.4
|
| 8 |
+
certifi==2025.11.12
|
| 9 |
+
charset-normalizer==3.4.4
|
| 10 |
+
click==8.3.1
|
| 11 |
+
colorama==0.4.6
|
| 12 |
+
contourpy==1.3.2
|
| 13 |
+
cycler==0.12.1
|
| 14 |
+
et_xmlfile==2.0.0
|
| 15 |
+
filelock==3.20.1
|
| 16 |
+
flatbuffers==25.12.19
|
| 17 |
+
fonttools==4.61.1
|
| 18 |
+
fsspec==2025.12.0
|
| 19 |
+
gast==0.7.0
|
| 20 |
+
gensim==4.4.0
|
| 21 |
+
gitdb==4.0.12
|
| 22 |
+
GitPython==3.1.46
|
| 23 |
+
google-pasta==0.2.0
|
| 24 |
+
grpcio==1.76.0
|
| 25 |
+
h5py==3.15.1
|
| 26 |
+
h11==0.16.0
|
| 27 |
+
hf-xet==1.2.0
|
| 28 |
+
httpcore==1.0.9
|
| 29 |
+
httpx==0.28.1
|
| 30 |
+
huggingface_hub==1.1.2
|
| 31 |
+
idna==3.11
|
| 32 |
+
Jinja2==3.1.6
|
| 33 |
+
joblib==1.5.3
|
| 34 |
+
jsonschema==4.25.1
|
| 35 |
+
jsonschema-specifications==2025.9.1
|
| 36 |
+
keras==3.13.0
|
| 37 |
+
kiwisolver==1.4.9
|
| 38 |
+
libclang==18.1.1
|
| 39 |
+
Markdown==3.10
|
| 40 |
+
markdown-it-py==4.0.0
|
| 41 |
+
MarkupSafe==3.0.3
|
| 42 |
+
matplotlib==3.10.8
|
| 43 |
+
mdurl==0.1.2
|
| 44 |
+
ml_dtypes==0.5.4
|
| 45 |
+
mpmath==1.3.0
|
| 46 |
+
namex==0.1.0
|
| 47 |
+
narwhals==2.14.0
|
| 48 |
+
networkx==3.6.1
|
| 49 |
+
numpy==2.4.0
|
| 50 |
+
openpyxl==3.1.5
|
| 51 |
+
opt_einsum==3.4.0
|
| 52 |
+
optree==0.18.0
|
| 53 |
+
packaging==25.0
|
| 54 |
+
pandas==2.3.3
|
| 55 |
+
patsy==1.0.2
|
| 56 |
+
pillow==12.0.0
|
| 57 |
+
plotly==6.5.0
|
| 58 |
+
protobuf==6.33.2
|
| 59 |
+
pyarrow==22.0.0
|
| 60 |
+
pydeck==0.9.1
|
| 61 |
+
Pygments==2.19.2
|
| 62 |
+
pyparsing==3.3.1
|
| 63 |
+
pyperclip==1.11.0
|
| 64 |
+
python-dateutil==2.9.0.post0
|
| 65 |
+
pytz==2025.2
|
| 66 |
+
PyYAML==6.0.3
|
| 67 |
+
referencing==0.37.0
|
| 68 |
+
requests==2.32.5
|
| 69 |
+
rich==14.2.0
|
| 70 |
+
rpds-py==0.30.0
|
| 71 |
+
scikit-learn==1.8.0
|
| 72 |
+
scipy==1.16.3
|
| 73 |
+
seaborn==0.13.2
|
| 74 |
+
setuptools==80.9.0
|
| 75 |
+
shellingham==1.5.4
|
| 76 |
+
six==1.17.0
|
| 77 |
+
smart_open==7.5.0
|
| 78 |
+
smmap==5.0.2
|
| 79 |
+
sniffio==1.3.1
|
| 80 |
+
statsmodels==0.14.6
|
| 81 |
+
streamlit==1.52.2
|
| 82 |
+
sympy==1.14.0
|
| 83 |
+
tenacity==9.1.2
|
| 84 |
+
tensorboard==2.20.0
|
| 85 |
+
tensorboard-data-server==0.7.2
|
| 86 |
+
tensorflow==2.20.0
|
| 87 |
+
termcolor==3.3.0
|
| 88 |
+
threadpoolctl==3.6.0
|
| 89 |
+
toml==0.10.2
|
| 90 |
+
torch==2.9.1
|
| 91 |
+
tornado==6.5.4
|
| 92 |
+
tqdm==4.67.1
|
| 93 |
+
typer-slim==0.20.0
|
| 94 |
+
typing_extensions==4.15.0
|
| 95 |
+
tzdata==2025.3
|
| 96 |
+
urllib3==2.6.2
|
| 97 |
+
watchdog==6.0.0
|
| 98 |
+
Werkzeug==3.1.4
|
| 99 |
+
wheel==0.45.1
|
| 100 |
+
wrapt==2.0.1
|
run_pipeline.py
ADDED
|
@@ -0,0 +1,62 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# RUN
|
| 3 |
+
# ============================================
|
| 4 |
+
from config.config import Config
|
| 5 |
+
from pipeline.main_pipeline import EnhancedDataPreprocessingPipeline
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
if __name__ == "__main__":
|
| 9 |
+
"""
|
| 10 |
+
Pipeline execution
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
# Configuration with reasonable parameters
|
| 14 |
+
config = Config(
|
| 15 |
+
data_path='temp_data.csv',
|
| 16 |
+
results_dir='enhanced_preprocessing_results',
|
| 17 |
+
target_column='raskhodvoda',
|
| 18 |
+
start_year=1970,
|
| 19 |
+
end_year=1990,
|
| 20 |
+
max_lags=5,
|
| 21 |
+
seasonal_period=365,
|
| 22 |
+
rolling_windows=[7, 30, 90],
|
| 23 |
+
expanding_windows=[30, 90],
|
| 24 |
+
test_size=0.2,
|
| 25 |
+
validation_size=0.1,
|
| 26 |
+
scaling_method='robust',
|
| 27 |
+
feature_selection_method='correlation',
|
| 28 |
+
max_features=20,
|
| 29 |
+
missing_threshold=0.3,
|
| 30 |
+
outlier_method='iqr',
|
| 31 |
+
enable_validation=True
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Run enhanced pipeline
|
| 35 |
+
pipeline = EnhancedDataPreprocessingPipeline(config)
|
| 36 |
+
processed_data = pipeline.run_full_pipeline(
|
| 37 |
+
use_synthetic=False,
|
| 38 |
+
save_intermediate=True,
|
| 39 |
+
create_reports=True
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if processed_data is not None:
|
| 43 |
+
print("\n" + "="*80)
|
| 44 |
+
print("ENHANCED PIPELINE SUCCESSFULLY COMPLETED!")
|
| 45 |
+
print("="*80)
|
| 46 |
+
print(f"Final data size: {processed_data.shape}")
|
| 47 |
+
print(f"Columns: {list(processed_data.columns)}")
|
| 48 |
+
|
| 49 |
+
# Get modeling data
|
| 50 |
+
modeling_data = pipeline.get_final_data_for_modelling()
|
| 51 |
+
|
| 52 |
+
if modeling_data:
|
| 53 |
+
print(f"\nModeling data ready:")
|
| 54 |
+
print(f" X_train: {modeling_data['X_train'].shape}")
|
| 55 |
+
print(f" X_val: {modeling_data['X_val'].shape}")
|
| 56 |
+
print(f" X_test: {modeling_data['X_test'].shape}")
|
| 57 |
+
print(f" Features: {len(modeling_data['feature_names'])}")
|
| 58 |
+
|
| 59 |
+
# Save final data
|
| 60 |
+
processed_data.to_csv('enhanced_preprocessing_results/processed_data/enhanced_final_processed_data.csv',
|
| 61 |
+
index=True if isinstance(processed_data.index, pd.DatetimeIndex) else False)
|
| 62 |
+
print(f"\n✓ Final data saved to 'enhanced_final_processed_data.csv'")
|
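Note on the save path above: building it with `os.path.join` keeps the script portable across operating systems. A sketch, assuming `processed_data` as defined in this script:

```python
import os

out_path = os.path.join('enhanced_preprocessing_results', 'processed_data',
                        'enhanced_final_processed_data.csv')
processed_data.to_csv(out_path,
                      index=isinstance(processed_data.index, pd.DatetimeIndex))
```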
scaling/__init__.py
ADDED
|
File without changes
|
scaling/data_scaler.py
ADDED
|
@@ -0,0 +1,634 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 10: DATA SCALING
|
| 3 |
+
# ============================================
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
import logging
logger = logging.getLogger(__name__)
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from config.config import Config
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DataScaler:
|
| 13 |
+
"""Class for data scaling and normalisation"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, config: Config):
|
| 16 |
+
"""
|
| 17 |
+
Initialise scaler
|
| 18 |
+
|
| 19 |
+
Parameters:
|
| 20 |
+
-----------
|
| 21 |
+
config : Config
|
| 22 |
+
Experiment configuration
|
| 23 |
+
"""
|
| 24 |
+
self.config = config
|
| 25 |
+
self.scalers = {}
|
| 26 |
+
self.scaling_info = {}
|
| 27 |
+
self.transforms_applied = {}
|
| 28 |
+
|
| 29 |
+
def fit_transform(
|
| 30 |
+
self,
|
| 31 |
+
data: pd.DataFrame,
|
| 32 |
+
method: str = None,
|
| 33 |
+
columns: List[str] = None,
|
| 34 |
+
target_col: Optional[str] = None,
|
| 35 |
+
fit_on_train: bool = True,
|
| 36 |
+
**kwargs
|
| 37 |
+
) -> pd.DataFrame:
|
| 38 |
+
"""
|
| 39 |
+
Scale data
|
| 40 |
+
|
| 41 |
+
Parameters:
|
| 42 |
+
-----------
|
| 43 |
+
data : pd.DataFrame
|
| 44 |
+
Input data
|
| 45 |
+
method : str, optional
|
| 46 |
+
Scaling method. If None, uses configuration value.
|
| 47 |
+
columns : List[str], optional
|
| 48 |
+
List of columns to scale. If None, uses all numeric columns.
|
| 49 |
+
target_col : str, optional
|
| 50 |
+
Target variable (not scaled by default)
|
| 51 |
+
fit_on_train : bool
|
| 52 |
+
Whether to save scaling parameters for applying to new data
|
| 53 |
+
**kwargs : dict
|
| 54 |
+
Additional parameters for method
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
--------
|
| 58 |
+
pd.DataFrame
|
| 59 |
+
Scaled data
|
| 60 |
+
"""
|
| 61 |
+
logger.info("\n" + "="*80)
|
| 62 |
+
logger.info("DATA SCALING")
|
| 63 |
+
logger.info("="*80)
|
| 64 |
+
|
| 65 |
+
method = method or self.config.scaling_method
|
| 66 |
+
data_scaled = data.copy()
|
| 67 |
+
|
| 68 |
+
if columns is None:
|
| 69 |
+
# Select all numeric columns except target
|
| 70 |
+
numeric_cols = data_scaled.select_dtypes(include=[np.number]).columns
|
| 71 |
+
if target_col and target_col in numeric_cols:
|
| 72 |
+
columns = [col for col in numeric_cols if col != target_col]
|
| 73 |
+
else:
|
| 74 |
+
columns = list(numeric_cols)
|
| 75 |
+
|
| 76 |
+
logger.info(f"Scaling method: {method}")
|
| 77 |
+
logger.info(f"Columns to process: {len(columns)}")
|
| 78 |
+
|
| 79 |
+
# Apply scaling
|
| 80 |
+
for col in columns:
|
| 81 |
+
if col in data_scaled.columns:
|
| 82 |
+
try:
|
| 83 |
+
# Check feature type
|
| 84 |
+
series = data_scaled[col].dropna()
|
| 85 |
+
|
| 86 |
+
# Special handling for different feature types
|
| 87 |
+
if self._is_binary_feature(series):
|
| 88 |
+
logger.debug(f" {col}: binary feature, scaling not applied")
|
| 89 |
+
scaler_info = {
|
| 90 |
+
'method': 'none',
|
| 91 |
+
'scaler_type': 'binary',
|
| 92 |
+
'original_values': sorted(series.unique().tolist()),
|
| 93 |
+
'note': 'binary feature, no scaling applied'
|
| 94 |
+
}
|
| 95 |
+
self.scaling_info[col] = scaler_info
|
| 96 |
+
|
| 97 |
+
if fit_on_train:
|
| 98 |
+
self.scalers[col] = scaler_info
|
| 99 |
+
|
| 100 |
+
elif self._is_categorical_feature(series):
|
| 101 |
+
logger.debug(f" {col}: categorical feature, using min-max")
|
| 102 |
+
scaled_series, scaler_info = self._apply_scaling(
|
| 103 |
+
data_scaled[col], 'minmax', fit_on_train, **kwargs
|
| 104 |
+
)
|
| 105 |
+
data_scaled[col] = scaled_series
|
| 106 |
+
self.scaling_info[col] = scaler_info
|
| 107 |
+
|
| 108 |
+
if fit_on_train:
|
| 109 |
+
if scaler_info.get('scaler_type') == 'sklearn':
|
| 110 |
+
self.scalers[col] = scaler_info['scaler_object']
|
| 111 |
+
else:
|
| 112 |
+
self.scalers[col] = scaler_info
|
| 113 |
+
|
| 114 |
+
else:
|
| 115 |
+
# Regular scaling for continuous features
|
| 116 |
+
scaled_series, scaler_info = self._apply_scaling(
|
| 117 |
+
data_scaled[col], method, fit_on_train, **kwargs
|
| 118 |
+
)
|
| 119 |
+
data_scaled[col] = scaled_series
|
| 120 |
+
self.scaling_info[col] = scaler_info
|
| 121 |
+
|
| 122 |
+
if fit_on_train:
|
| 123 |
+
if scaler_info.get('scaler_type') == 'sklearn':
|
| 124 |
+
self.scalers[col] = scaler_info['scaler_object']
|
| 125 |
+
else:
|
| 126 |
+
self.scalers[col] = scaler_info
|
| 127 |
+
|
| 128 |
+
except Exception as e:
|
| 129 |
+
logger.warning(f"Error processing column {col}: {e}")
|
| 130 |
+
# Save error information
|
| 131 |
+
self.scaling_info[col] = {
|
| 132 |
+
'method': 'error',
|
| 133 |
+
'error': str(e),
|
| 134 |
+
'scaler_type': 'none'
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
logger.info(f"✓ Data processed using {method} method")
|
| 138 |
+
|
| 139 |
+
# Visualisation of results
|
| 140 |
+
if self.config.save_plots and columns:
|
| 141 |
+
self._plot_scaling_results(data, data_scaled, columns, method)
|
| 142 |
+
|
| 143 |
+
return data_scaled
|
| 144 |
+
|
| 145 |
+
def _is_binary_feature(self, series: pd.Series) -> bool:
|
| 146 |
+
"""Check if feature is binary"""
|
| 147 |
+
unique_values = series.dropna().unique()
|
| 148 |
+
return len(unique_values) == 2 and set(unique_values).issubset({0, 1})
|
| 149 |
+
|
| 150 |
+
def _is_categorical_feature(self, series: pd.Series, max_categories: int = 10) -> bool:
|
| 151 |
+
"""Check if feature is categorical"""
|
| 152 |
+
unique_values = series.dropna().unique()
|
| 153 |
+
return len(unique_values) <= max_categories and series.dtype in ['int64', 'float64']
|
| 154 |
+
|
| 155 |
+
def _apply_scaling(
|
| 156 |
+
self,
|
| 157 |
+
series: pd.Series,
|
| 158 |
+
method: str,
|
| 159 |
+
fit_on_train: bool,
|
| 160 |
+
**kwargs
|
| 161 |
+
) -> Tuple[pd.Series, Dict]:
|
| 162 |
+
"""Apply specific scaling method"""
|
| 163 |
+
series_clean = series.dropna()
|
| 164 |
+
|
| 165 |
+
if len(series_clean) == 0:
|
| 166 |
+
return series, {
|
| 167 |
+
'method': 'none',
|
| 168 |
+
'scaler_type': 'none',
|
| 169 |
+
'error': 'all values are NaN'
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
scaler_info = {
|
| 173 |
+
'method': method,
|
| 174 |
+
'scaler_type': 'simple',
|
| 175 |
+
'original_mean': float(series_clean.mean()),
|
| 176 |
+
'original_std': float(series_clean.std()),
|
| 177 |
+
'original_min': float(series_clean.min()),
|
| 178 |
+
'original_max': float(series_clean.max()),
|
| 179 |
+
'scaler': None,
|
| 180 |
+
'scaler_object': None
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
if method == 'standard':
|
| 185 |
+
# Standardisation (z-score normalisation)
|
| 186 |
+
mean = series_clean.mean()
|
| 187 |
+
std = series_clean.std()
|
| 188 |
+
|
| 189 |
+
if std > 0:
|
| 190 |
+
series_scaled = (series - mean) / std
|
| 191 |
+
scaler_info['scaler'] = {'mean': float(mean), 'std': float(std)}
|
| 192 |
+
else:
|
| 193 |
+
series_scaled = series - mean # If std = 0, just center
|
| 194 |
+
scaler_info['scaler'] = {'mean': float(mean), 'std': 0}
|
| 195 |
+
|
| 196 |
+
elif method == 'minmax':
|
| 197 |
+
# Min-Max normalisation
|
| 198 |
+
min_val = series_clean.min()
|
| 199 |
+
max_val = series_clean.max()
|
| 200 |
+
|
| 201 |
+
if max_val > min_val:
|
| 202 |
+
series_scaled = (series - min_val) / (max_val - min_val)
|
| 203 |
+
scaler_info['scaler'] = {'min': float(min_val), 'max': float(max_val)}
|
| 204 |
+
else:
|
| 205 |
+
series_scaled = series - min_val # If all values equal
|
| 206 |
+
scaler_info['scaler'] = {'min': float(min_val), 'max': float(min_val)}
|
| 207 |
+
|
| 208 |
+
elif method == 'robust':
|
| 209 |
+
# Robust scaling (outlier resistant)
|
| 210 |
+
# Check sufficient values for quartile calculation
|
| 211 |
+
if len(series_clean) >= 4:
|
| 212 |
+
median = series_clean.median()
|
| 213 |
+
q1 = series_clean.quantile(0.25)
|
| 214 |
+
q3 = series_clean.quantile(0.75)
|
| 215 |
+
iqr = q3 - q1
|
| 216 |
+
|
| 217 |
+
if iqr > 0:
|
| 218 |
+
series_scaled = (series - median) / iqr
|
| 219 |
+
scaler_info['scaler'] = {
|
| 220 |
+
'median': float(median),
|
| 221 |
+
'q1': float(q1),
|
| 222 |
+
'q3': float(q3),
|
| 223 |
+
'iqr': float(iqr)
|
| 224 |
+
}
|
| 225 |
+
else:
|
| 226 |
+
# If IQR = 0, use standard deviation
|
| 227 |
+
std = series_clean.std()
|
| 228 |
+
if std > 0:
|
| 229 |
+
series_scaled = (series - median) / std
|
| 230 |
+
scaler_info['scaler'] = {'median': float(median), 'std': float(std)}
|
| 231 |
+
else:
|
| 232 |
+
series_scaled = series - median
|
| 233 |
+
scaler_info['scaler'] = {'median': float(median), 'iqr': 0}
|
| 234 |
+
else:
|
| 235 |
+
# If insufficient data, use standardisation
|
| 236 |
+
mean = series_clean.mean()
|
| 237 |
+
std = series_clean.std()
|
| 238 |
+
if std > 0:
|
| 239 |
+
series_scaled = (series - mean) / std
|
| 240 |
+
scaler_info['scaler'] = {'mean': float(mean), 'std': float(std)}
|
| 241 |
+
scaler_info['method'] = 'standard' # Change method in info
|
| 242 |
+
else:
|
| 243 |
+
series_scaled = series - mean
|
| 244 |
+
scaler_info['scaler'] = {'mean': float(mean), 'std': 0}
|
| 245 |
+
scaler_info['method'] = 'standard'
|
| 246 |
+
|
| 247 |
+
elif method == 'log':
|
| 248 |
+
# Logarithmic transformation
|
| 249 |
+
min_val = series_clean.min()
|
| 250 |
+
|
| 251 |
+
if min_val <= 0:
|
| 252 |
+
shift = abs(min_val) + 1
|
| 253 |
+
series_scaled = np.log(series + shift)
|
| 254 |
+
scaler_info['scaler'] = {'shift': float(shift)}
|
| 255 |
+
else:
|
| 256 |
+
series_scaled = np.log(series)
|
| 257 |
+
scaler_info['scaler'] = {'shift': 0}
|
| 258 |
+
|
| 259 |
+
elif method == 'boxcox':
|
| 260 |
+
# Box-Cox transformation
|
| 261 |
+
try:
|
| 262 |
+
from scipy.stats import boxcox
|
| 263 |
+
|
| 264 |
+
min_val = series_clean.min()
|
| 265 |
+
if min_val <= 0:
|
| 266 |
+
shift = abs(min_val) + 1
|
| 267 |
+
series_to_transform = series + shift
|
| 268 |
+
else:
|
| 269 |
+
shift = 0
|
| 270 |
+
series_to_transform = series
|
| 271 |
+
|
| 272 |
+
transformed, lambda_val = boxcox(series_to_transform.dropna())
|
| 273 |
+
|
| 274 |
+
# Interpolate for all values
|
| 275 |
+
series_scaled = series.copy()
|
| 276 |
+
valid_mask = series_to_transform.notna()
|
| 277 |
+
series_scaled[valid_mask] = transformed
|
| 278 |
+
|
| 279 |
+
scaler_info['scaler'] = {
|
| 280 |
+
'lambda': float(lambda_val),
|
| 281 |
+
'shift': float(shift)
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
except Exception as e:
|
| 285 |
+
logger.warning(f"Box-Cox transformation failed for {series.name}: {e}")
|
| 286 |
+
# Return original series and change method
|
| 287 |
+
series_scaled = series
|
| 288 |
+
scaler_info['method'] = 'none'
|
| 289 |
+
scaler_info['scaler_type'] = 'none'
|
| 290 |
+
scaler_info['error'] = str(e)
|
| 291 |
+
|
| 292 |
+
elif method == 'quantile':
|
| 293 |
+
# Quantile transformation (rank-based)
|
| 294 |
+
try:
|
| 295 |
+
from sklearn.preprocessing import QuantileTransformer
|
| 296 |
+
|
| 297 |
+
qt = QuantileTransformer(
|
| 298 |
+
n_quantiles=kwargs.get('n_quantiles', min(100, len(series_clean))),
|
| 299 |
+
output_distribution=kwargs.get('output_distribution', 'normal'),
|
| 300 |
+
random_state=kwargs.get('random_state', 42)
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
series_reshaped = series.values.reshape(-1, 1)
|
| 304 |
+
series_scaled_values = qt.fit_transform(series_reshaped)
|
| 305 |
+
series_scaled = pd.Series(series_scaled_values.flatten(), index=series.index)
|
| 306 |
+
|
| 307 |
+
scaler_info['scaler_type'] = 'sklearn'
|
| 308 |
+
scaler_info['scaler_object'] = qt
|
| 309 |
+
|
| 310 |
+
except Exception as e:
|
| 311 |
+
logger.warning(f"Quantile transform failed for {series.name}: {e}")
|
| 312 |
+
series_scaled = series
|
| 313 |
+
scaler_info['method'] = 'none'
|
| 314 |
+
scaler_info['scaler_type'] = 'none'
|
| 315 |
+
scaler_info['error'] = str(e)
|
| 316 |
+
|
| 317 |
+
elif method == 'power':
|
| 318 |
+
# Power transform (Yeo-Johnson)
|
| 319 |
+
try:
|
| 320 |
+
from sklearn.preprocessing import PowerTransformer
|
| 321 |
+
|
| 322 |
+
pt = PowerTransformer(method='yeo-johnson', standardize=True)
|
| 323 |
+
|
| 324 |
+
series_reshaped = series.values.reshape(-1, 1)
|
| 325 |
+
series_scaled_values = pt.fit_transform(series_reshaped)
|
| 326 |
+
series_scaled = pd.Series(series_scaled_values.flatten(), index=series.index)
|
| 327 |
+
|
| 328 |
+
scaler_info['scaler_type'] = 'sklearn'
|
| 329 |
+
scaler_info['scaler_object'] = pt
|
| 330 |
+
|
| 331 |
+
except Exception as e:
|
| 332 |
+
logger.warning(f"Power transform failed for {series.name}: {e}")
|
| 333 |
+
series_scaled = series
|
| 334 |
+
scaler_info['method'] = 'none'
|
| 335 |
+
scaler_info['scaler_type'] = 'none'
|
| 336 |
+
scaler_info['error'] = str(e)
|
| 337 |
+
|
| 338 |
+
elif method == 'none':
|
| 339 |
+
# No scaling
|
| 340 |
+
series_scaled = series
|
| 341 |
+
scaler_info['method'] = 'none'
|
| 342 |
+
scaler_info['scaler_type'] = 'none'
|
| 343 |
+
|
| 344 |
+
else:
|
| 345 |
+
logger.warning(f"Unknown scaling method: {method}, using standardisation")
|
| 346 |
+
return self._apply_scaling(series, 'standard', fit_on_train, **kwargs)
|
| 347 |
+
|
| 348 |
+
# Add statistics after scaling
|
| 349 |
+
scaled_clean = series_scaled.dropna()
|
| 350 |
+
if len(scaled_clean) > 0:
|
| 351 |
+
scaler_info.update({
|
| 352 |
+
'scaled_mean': float(scaled_clean.mean()),
|
| 353 |
+
'scaled_std': float(scaled_clean.std()),
|
| 354 |
+
'scaled_min': float(scaled_clean.min()),
|
| 355 |
+
'scaled_max': float(scaled_clean.max()),
|
| 356 |
+
'skewness_before': float(series_clean.skew()),
|
| 357 |
+
'skewness_after': float(scaled_clean.skew()),
|
| 358 |
+
'kurtosis_before': float(series_clean.kurtosis()),
|
| 359 |
+
'kurtosis_after': float(scaled_clean.kurtosis())
|
| 360 |
+
})
|
| 361 |
+
|
| 362 |
+
return series_scaled, scaler_info
|
| 363 |
+
|
| 364 |
+
except Exception as e:
|
| 365 |
+
logger.warning(f"Error applying method {method} for {series.name}: {e}")
|
| 366 |
+
return series, {
|
| 367 |
+
'method': 'error',
|
| 368 |
+
'scaler_type': 'none',
|
| 369 |
+
'error': str(e)
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
def transform(
|
| 373 |
+
self,
|
| 374 |
+
data: pd.DataFrame,
|
| 375 |
+
columns: List[str] = None
|
| 376 |
+
) -> pd.DataFrame:
|
| 377 |
+
"""
|
| 378 |
+
Apply saved scaling to new data
|
| 379 |
+
|
| 380 |
+
Parameters:
|
| 381 |
+
-----------
|
| 382 |
+
data : pd.DataFrame
|
| 383 |
+
New data
|
| 384 |
+
columns : List[str], optional
|
| 385 |
+
List of columns to transform
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
--------
|
| 389 |
+
pd.DataFrame
|
| 390 |
+
Transformed data
|
| 391 |
+
"""
|
| 392 |
+
if not self.scalers:
|
| 393 |
+
logger.warning("Scalers not trained, use fit_transform first")
|
| 394 |
+
return data
|
| 395 |
+
|
| 396 |
+
data_transformed = data.copy()
|
| 397 |
+
|
| 398 |
+
if columns is None:
|
| 399 |
+
columns = [col for col in self.scalers.keys() if col in data_transformed.columns]
|
| 400 |
+
|
| 401 |
+
for col in columns:
|
| 402 |
+
if col in data_transformed.columns and col in self.scalers:
|
| 403 |
+
try:
|
| 404 |
+
scaler_info = self.scaling_info.get(col, {})
|
| 405 |
+
scaler_data = self.scalers[col]
|
| 406 |
+
method = scaler_info.get('method', 'unknown')
|
| 407 |
+
|
| 408 |
+
# For binary features, do nothing
|
| 409 |
+
if method == 'none' and scaler_info.get('scaler_type') == 'binary':
|
| 410 |
+
continue
|
| 411 |
+
|
| 412 |
+
# Skip errors
|
| 413 |
+
if method == 'error':
|
| 414 |
+
continue
|
| 415 |
+
|
| 416 |
+
if isinstance(scaler_data, dict) and 'scaler' in scaler_data:
|
| 417 |
+
scaler_params = scaler_data['scaler']
|
| 418 |
+
|
| 419 |
+
if method == 'standard':
|
| 420 |
+
mean = scaler_params.get('mean', 0)
|
| 421 |
+
std = scaler_params.get('std', 1)
|
| 422 |
+
if std > 0:
|
| 423 |
+
data_transformed[col] = (data_transformed[col] - mean) / std
|
| 424 |
+
|
| 425 |
+
elif method == 'minmax':
|
| 426 |
+
min_val = scaler_params.get('min', 0)
|
| 427 |
+
max_val = scaler_params.get('max', 1)
|
| 428 |
+
if max_val > min_val:
|
| 429 |
+
data_transformed[col] = (data_transformed[col] - min_val) / (max_val - min_val)
|
| 430 |
+
|
| 431 |
+
elif method == 'robust':
|
| 432 |
+
median = scaler_params.get('median', 0)
|
| 433 |
+
iqr = scaler_params.get('iqr', 1)
|
| 434 |
+
if iqr > 0:
|
| 435 |
+
data_transformed[col] = (data_transformed[col] - median) / iqr
|
| 436 |
+
else:
|
| 437 |
+
std = scaler_params.get('std', 1)
|
| 438 |
+
if std > 0:
|
| 439 |
+
data_transformed[col] = (data_transformed[col] - median) / std
|
| 440 |
+
|
| 441 |
+
elif hasattr(scaler_data, 'transform'):
|
| 442 |
+
# For sklearn objects
|
| 443 |
+
from sklearn.base import BaseEstimator
|
| 444 |
+
if isinstance(scaler_data, BaseEstimator):
|
| 445 |
+
try:
|
| 446 |
+
transformed = scaler_data.transform(
|
| 447 |
+
data_transformed[[col]].values.reshape(-1, 1)
|
| 448 |
+
).flatten()
|
| 449 |
+
data_transformed[col] = transformed
|
| 450 |
+
except Exception as e:
|
| 451 |
+
logger.warning(f"Error in sklearn transformation for {col}: {e}")
|
| 452 |
+
|
| 453 |
+
except Exception as e:
|
| 454 |
+
logger.warning(f"Error transforming column {col}: {e}")
|
| 455 |
+
|
| 456 |
+
return data_transformed
|
| 457 |
+
|
| 458 |
+
def inverse_transform(
|
| 459 |
+
self,
|
| 460 |
+
data: pd.DataFrame,
|
| 461 |
+
columns: List[str] = None
|
| 462 |
+
) -> pd.DataFrame:
|
| 463 |
+
"""
|
| 464 |
+
Inverse transform scaled data
|
| 465 |
+
|
| 466 |
+
Parameters:
|
| 467 |
+
-----------
|
| 468 |
+
data : pd.DataFrame
|
| 469 |
+
Scaled data
|
| 470 |
+
columns : List[str], optional
|
| 471 |
+
List of columns for inverse transform
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
--------
|
| 475 |
+
pd.DataFrame
|
| 476 |
+
Data in original scale
|
| 477 |
+
"""
|
| 478 |
+
if not self.scalers:
|
| 479 |
+
logger.warning("Scalers not trained")
|
| 480 |
+
return data
|
| 481 |
+
|
| 482 |
+
data_inverse = data.copy()
|
| 483 |
+
|
| 484 |
+
if columns is None:
|
| 485 |
+
columns = [col for col in self.scalers.keys() if col in data_inverse.columns]
|
| 486 |
+
|
| 487 |
+
for col in columns:
|
| 488 |
+
if col in data_inverse.columns and col in self.scalers:
|
| 489 |
+
try:
|
| 490 |
+
scaler_info = self.scaling_info.get(col, {})
|
| 491 |
+
scaler_data = self.scalers[col]
|
| 492 |
+
method = scaler_info.get('method', 'unknown')
|
| 493 |
+
|
| 494 |
+
# For binary and categorical features, do nothing
|
| 495 |
+
if method in ['none', 'error']:
|
| 496 |
+
continue
|
| 497 |
+
|
| 498 |
+
if isinstance(scaler_data, dict) and 'scaler' in scaler_data:
|
| 499 |
+
scaler_params = scaler_data['scaler']
|
| 500 |
+
|
| 501 |
+
if method == 'standard':
|
| 502 |
+
mean = scaler_params.get('mean', 0)
|
| 503 |
+
std = scaler_params.get('std', 1)
|
| 504 |
+
data_inverse[col] = data_inverse[col] * std + mean
|
| 505 |
+
|
| 506 |
+
elif method == 'minmax':
|
| 507 |
+
min_val = scaler_params.get('min', 0)
|
| 508 |
+
max_val = scaler_params.get('max', 1)
|
| 509 |
+
if max_val > min_val:
|
| 510 |
+
data_inverse[col] = data_inverse[col] * (max_val - min_val) + min_val
|
| 511 |
+
|
| 512 |
+
elif method == 'robust':
|
| 513 |
+
median = scaler_params.get('median', 0)
|
| 514 |
+
iqr = scaler_params.get('iqr', 1)
|
| 515 |
+
if iqr > 0:
|
| 516 |
+
data_inverse[col] = data_inverse[col] * iqr + median
|
| 517 |
+
else:
|
| 518 |
+
std = scaler_params.get('std', 1)
|
| 519 |
+
if std > 0:
|
| 520 |
+
data_inverse[col] = data_inverse[col] * std + median
|
| 521 |
+
|
| 522 |
+
elif hasattr(scaler_data, 'inverse_transform'):
|
| 523 |
+
# For sklearn objects
|
| 524 |
+
from sklearn.base import BaseEstimator
|
| 525 |
+
if isinstance(scaler_data, BaseEstimator):
|
| 526 |
+
try:
|
| 527 |
+
inverse_transformed = scaler_data.inverse_transform(
|
| 528 |
+
data_inverse[[col]].values.reshape(-1, 1)
|
| 529 |
+
).flatten()
|
| 530 |
+
data_inverse[col] = inverse_transformed
|
| 531 |
+
except Exception as e:
|
| 532 |
+
logger.warning(f"Error in sklearn inverse transformation for {col}: {e}")
|
| 533 |
+
|
| 534 |
+
except Exception as e:
|
| 535 |
+
logger.warning(f"Error in inverse transformation for column {col}: {e}")
|
| 536 |
+
|
| 537 |
+
return data_inverse
|
| 538 |
+
|
| 539 |
+
def _plot_scaling_results(
|
| 540 |
+
self,
|
| 541 |
+
original_data: pd.DataFrame,
|
| 542 |
+
scaled_data: pd.DataFrame,
|
| 543 |
+
columns: List[str],
|
| 544 |
+
method: str
|
| 545 |
+
) -> None:
|
| 546 |
+
"""Visualise scaling results"""
|
| 547 |
+
# Limit number of columns for visualisation
|
| 548 |
+
cols_to_plot = [col for col in columns if col in original_data.columns and col in scaled_data.columns][:8]
|
| 549 |
+
|
| 550 |
+
if not cols_to_plot:
|
| 551 |
+
return
|
| 552 |
+
|
| 553 |
+
n_cols = 4
|
| 554 |
+
n_rows = (len(cols_to_plot) + n_cols - 1) // n_cols
|
| 555 |
+
|
| 556 |
+
fig, axes = plt.subplots(n_rows, n_cols * 2, figsize=(16, 4 * n_rows), squeeze=False)  # squeeze=False keeps axes 2-D even when n_rows == 1
|
| 557 |
+
|
| 558 |
+
for idx, col in enumerate(cols_to_plot):
|
| 559 |
+
row = idx // n_cols
|
| 560 |
+
col_idx = (idx % n_cols) * 2
|
| 561 |
+
|
| 562 |
+
# Distribution before scaling
|
| 563 |
+
axes[row, col_idx].hist(
|
| 564 |
+
original_data[col].dropna(),
|
| 565 |
+
bins=30,
|
| 566 |
+
alpha=0.7,
|
| 567 |
+
color='blue',
|
| 568 |
+
density=True
|
| 569 |
+
)
|
| 570 |
+
axes[row, col_idx].set_title(f'{col} (before)', fontsize=10)
|
| 571 |
+
axes[row, col_idx].set_xlabel('Value')
|
| 572 |
+
axes[row, col_idx].set_ylabel('Density')
|
| 573 |
+
axes[row, col_idx].grid(True, alpha=0.3)
|
| 574 |
+
|
| 575 |
+
# Distribution after scaling
|
| 576 |
+
axes[row, col_idx + 1].hist(
|
| 577 |
+
scaled_data[col].dropna(),
|
| 578 |
+
bins=30,
|
| 579 |
+
alpha=0.7,
|
| 580 |
+
color='green',
|
| 581 |
+
density=True
|
| 582 |
+
)
|
| 583 |
+
axes[row, col_idx + 1].set_title(f'{col} (after)', fontsize=10)
|
| 584 |
+
axes[row, col_idx + 1].set_xlabel('Scaled value')
|
| 585 |
+
axes[row, col_idx + 1].set_ylabel('Density')
|
| 586 |
+
axes[row, col_idx + 1].grid(True, alpha=0.3)
|
| 587 |
+
|
| 588 |
+
# Hide unused subplots
|
| 589 |
+
total_plots = n_rows * n_cols * 2
|
| 590 |
+
for idx in range(len(cols_to_plot) * 2, total_plots):
|
| 591 |
+
row = idx // (n_cols * 2)
|
| 592 |
+
col_idx = idx % (n_cols * 2)
|
| 593 |
+
axes[row, col_idx].set_visible(False)
|
| 594 |
+
|
| 595 |
+
plt.suptitle(f'Scaling results using {method} method', fontsize=14)
|
| 596 |
+
plt.tight_layout()
|
| 597 |
+
plt.savefig(
|
| 598 |
+
f'{self.config.results_dir}/plots/scaling_results.png',
|
| 599 |
+
dpi=300,
|
| 600 |
+
bbox_inches='tight'
|
| 601 |
+
)
|
| 602 |
+
plt.show()
|
| 603 |
+
|
| 604 |
+
def get_report(self) -> Dict:
|
| 605 |
+
"""Get scaling report"""
|
| 606 |
+
summary = {
|
| 607 |
+
'total_columns': len(self.scaling_info),
|
| 608 |
+
'methods_used': {},
|
| 609 |
+
'binary_features': [],
|
| 610 |
+
'categorical_features': [],
|
| 611 |
+
'continuous_features': [],
|
| 612 |
+
'errors': []
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
for col, info in self.scaling_info.items():
|
| 616 |
+
method = info.get('method', 'unknown')
|
| 617 |
+
if method not in summary['methods_used']:
|
| 618 |
+
summary['methods_used'][method] = 0
|
| 619 |
+
summary['methods_used'][method] += 1
|
| 620 |
+
|
| 621 |
+
if method == 'none' and info.get('scaler_type') == 'binary':
|
| 622 |
+
summary['binary_features'].append(col)
|
| 623 |
+
elif method in ['minmax', 'standard', 'robust']:
|
| 624 |
+
summary['continuous_features'].append(col)
|
| 625 |
+
elif method == 'error':
|
| 626 |
+
summary['errors'].append({
|
| 627 |
+
'column': col,
|
| 628 |
+
'error': info.get('error', 'unknown')
|
| 629 |
+
})
|
| 630 |
+
|
| 631 |
+
return {
|
| 632 |
+
'summary': summary,
|
| 633 |
+
'details': self.scaling_info
|
| 634 |
+
}
|
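A minimal usage sketch for the DataScaler class above, assuming Config() can be built with defaults and exposes the attributes the class reads (scaling_method, save_plots, results_dir, here assumed to have save_plots disabled); the frame and column names are hypothetical.

import pandas as pd

from config.config import Config
from scaling.data_scaler import DataScaler

# Hypothetical frame: one continuous feature, one binary flag, one target.
df = pd.DataFrame({
    "temperature": [21.0, 23.5, 19.8, 25.1, 22.4],
    "is_holiday": [0, 1, 0, 0, 1],
    "demand": [100.0, 80.0, 110.0, 95.0, 102.0],
})

config = Config()  # assumed defaults, e.g. scaling_method='standard', save_plots=False
scaler = DataScaler(config)

# Fit scaling parameters on the training frame; the target column stays unscaled
# and the binary flag is detected and left untouched.
train_scaled = scaler.fit_transform(df, method="standard", target_col="demand")

# Reuse the fitted parameters on new data, and map scaled values back when needed.
new_scaled = scaler.transform(df)
restored = scaler.inverse_transform(new_scaled)

print(scaler.get_report()["summary"]["methods_used"])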
splitting/__init__.py
ADDED
|
File without changes
|
splitting/data_splitter.py
ADDED
|
@@ -0,0 +1,403 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 9: DATA SPLITTING
|
| 3 |
+
# ============================================
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Dict, Optional, Tuple
|
| 6 |
+
import logging

logger = logging.getLogger(__name__)
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from config.config import Config
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DataSplitter:
|
| 14 |
+
"""Class for splitting data into train, validation and test sets"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, config: Config):
|
| 17 |
+
"""
|
| 18 |
+
Initialise data splitter
|
| 19 |
+
|
| 20 |
+
Parameters:
|
| 21 |
+
-----------
|
| 22 |
+
config : Config
|
| 23 |
+
Experiment configuration
|
| 24 |
+
"""
|
| 25 |
+
self.config = config
|
| 26 |
+
self.split_info = {}
|
| 27 |
+
self.split_indices = {}
|
| 28 |
+
self.split_strategy = None
|
| 29 |
+
|
| 30 |
+
def split(
|
| 31 |
+
self,
|
| 32 |
+
data: pd.DataFrame,
|
| 33 |
+
test_size: Optional[float] = None,
|
| 34 |
+
validation_size: Optional[float] = None,
|
| 35 |
+
method: str = None,
|
| 36 |
+
random_state: int = 42,
|
| 37 |
+
**kwargs
|
| 38 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
| 39 |
+
"""
|
| 40 |
+
Split data into train, validation and test sets
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
-----------
|
| 44 |
+
data : pd.DataFrame
|
| 45 |
+
Input data
|
| 46 |
+
test_size : float, optional
|
| 47 |
+
Test set size. If None, uses configuration value.
|
| 48 |
+
validation_size : float, optional
|
| 49 |
+
Validation set size. If None, uses configuration value.
|
| 50 |
+
method : str, optional
|
| 51 |
+
Splitting method: 'time', 'random', 'expanding_window', 'sliding_window'
|
| 52 |
+
random_state : int
|
| 53 |
+
Seed for reproducibility
|
| 54 |
+
**kwargs : dict
|
| 55 |
+
Additional parameters for method
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
--------
|
| 59 |
+
Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
|
| 60 |
+
Train, validation and test data
|
| 61 |
+
"""
|
| 62 |
+
logger.info("\n" + "="*80)
|
| 63 |
+
logger.info("DATA SPLITTING")
|
| 64 |
+
logger.info("="*80)
|
| 65 |
+
|
| 66 |
+
test_size = test_size or self.config.test_size
|
| 67 |
+
validation_size = validation_size or self.config.validation_size
|
| 68 |
+
method = method or self.config.split_method
|
| 69 |
+
|
| 70 |
+
n = len(data)
|
| 71 |
+
|
| 72 |
+
logger.info(f"Total data: {n} records")
|
| 73 |
+
logger.info(f"Splitting method: {method}")
|
| 74 |
+
logger.info(f"Sizes: train={1-test_size-validation_size:.1%}, val={validation_size:.1%}, test={test_size:.1%}")
|
| 75 |
+
|
| 76 |
+
if method == 'time':
|
| 77 |
+
train_data, val_data, test_data = self._time_based_split(
|
| 78 |
+
data, test_size, validation_size
|
| 79 |
+
)
|
| 80 |
+
elif method == 'random':
|
| 81 |
+
train_data, val_data, test_data = self._random_split(
|
| 82 |
+
data, test_size, validation_size, random_state
|
| 83 |
+
)
|
| 84 |
+
elif method == 'expanding_window':
|
| 85 |
+
train_data, val_data, test_data = self._expanding_window_split(
|
| 86 |
+
data, test_size, validation_size, **kwargs
|
| 87 |
+
)
|
| 88 |
+
elif method == 'sliding_window':
|
| 89 |
+
train_data, val_data, test_data = self._sliding_window_split(
|
| 90 |
+
data, **kwargs
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
logger.warning(f"Method {method} not supported, using time-based split")
|
| 94 |
+
train_data, val_data, test_data = self._time_based_split(
|
| 95 |
+
data, test_size, validation_size
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Save splitting information
|
| 99 |
+
self._save_split_info(data, train_data, val_data, test_data, method)
|
| 100 |
+
|
| 101 |
+
# Output information
|
| 102 |
+
self._log_split_summary(train_data, val_data, test_data)
|
| 103 |
+
|
| 104 |
+
# Visualisation of split
|
| 105 |
+
if self.config.save_plots:
|
| 106 |
+
self._plot_data_split(data, train_data, val_data, test_data)
|
| 107 |
+
|
| 108 |
+
return train_data, val_data, test_data
|
| 109 |
+
|
| 110 |
+
def _time_based_split(
|
| 111 |
+
self,
|
| 112 |
+
data: pd.DataFrame,
|
| 113 |
+
test_size: float,
|
| 114 |
+
validation_size: float
|
| 115 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
| 116 |
+
"""Time-based splitting preserving temporal order"""
|
| 117 |
+
n = len(data)
|
| 118 |
+
|
| 119 |
+
# Calculate set sizes
|
| 120 |
+
test_size_int = int(n * test_size)
|
| 121 |
+
val_size_int = int(n * validation_size)
|
| 122 |
+
train_size_int = n - test_size_int - val_size_int
|
| 123 |
+
|
| 124 |
+
# Split data
|
| 125 |
+
train_data = data.iloc[:train_size_int].copy()
|
| 126 |
+
val_data = data.iloc[train_size_int:train_size_int + val_size_int].copy()
|
| 127 |
+
test_data = data.iloc[train_size_int + val_size_int:].copy()
|
| 128 |
+
|
| 129 |
+
self.split_strategy = 'time_based'
|
| 130 |
+
|
| 131 |
+
return train_data, val_data, test_data
|
| 132 |
+
|
| 133 |
+
def _random_split(
|
| 134 |
+
self,
|
| 135 |
+
data: pd.DataFrame,
|
| 136 |
+
test_size: float,
|
| 137 |
+
validation_size: float,
|
| 138 |
+
random_state: int
|
| 139 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
| 140 |
+
"""Random data splitting"""
|
| 141 |
+
from sklearn.model_selection import train_test_split
|
| 142 |
+
|
| 143 |
+
# First split into train+val and test
|
| 144 |
+
train_val_data, test_data = train_test_split(
|
| 145 |
+
data,
|
| 146 |
+
test_size=test_size,
|
| 147 |
+
random_state=random_state,
|
| 148 |
+
shuffle=True
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Then split train+val into train and val
|
| 152 |
+
val_relative_size = validation_size / (1 - test_size)
|
| 153 |
+
train_data, val_data = train_test_split(
|
| 154 |
+
train_val_data,
|
| 155 |
+
test_size=val_relative_size,
|
| 156 |
+
random_state=random_state,
|
| 157 |
+
shuffle=True
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
self.split_strategy = 'random'
|
| 161 |
+
|
| 162 |
+
return train_data, val_data, test_data
|
| 163 |
+
|
| 164 |
+
def _expanding_window_split(
|
| 165 |
+
self,
|
| 166 |
+
data: pd.DataFrame,
|
| 167 |
+
test_size: float,
|
| 168 |
+
validation_size: float,
|
| 169 |
+
**kwargs
|
| 170 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
| 171 |
+
"""Expanding window split"""
|
| 172 |
+
n = len(data)
|
| 173 |
+
|
| 174 |
+
# Minimum initial window size
|
| 175 |
+
initial_window = kwargs.get('initial_window', max(100, int(n * 0.1)))
|
| 176 |
+
|
| 177 |
+
# Final set sizes
|
| 178 |
+
test_size_int = int(n * test_size)
|
| 179 |
+
val_size_int = int(n * validation_size)
|
| 180 |
+
|
| 181 |
+
# Determine boundaries
|
| 182 |
+
test_start = n - test_size_int
|
| 183 |
+
val_start = test_start - val_size_int
|
| 184 |
+
|
| 185 |
+
# For expanding window, use all data up to val_start for training
|
| 186 |
+
train_data = data.iloc[:val_start].copy()
|
| 187 |
+
val_data = data.iloc[val_start:test_start].copy()
|
| 188 |
+
test_data = data.iloc[test_start:].copy()
|
| 189 |
+
|
| 190 |
+
self.split_strategy = 'expanding_window'
|
| 191 |
+
|
| 192 |
+
return train_data, val_data, test_data
|
| 193 |
+
|
| 194 |
+
def _sliding_window_split(
|
| 195 |
+
self,
|
| 196 |
+
data: pd.DataFrame,
|
| 197 |
+
**kwargs
|
| 198 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
| 199 |
+
"""Sliding window split (for multiple train-val-test pairs)"""
|
| 200 |
+
window_size = kwargs.get('window_size', len(data) // 3)
|
| 201 |
+
step = kwargs.get('step', window_size // 2)
|
| 202 |
+
|
| 203 |
+
# For simplicity, return a single split
|
| 204 |
+
# In real scenarios this could return a list of splits
|
| 205 |
+
n = len(data)
|
| 206 |
+
|
| 207 |
+
train_end = n - window_size
|
| 208 |
+
val_end = train_end + window_size // 3
|
| 209 |
+
test_end = n
|
| 210 |
+
|
| 211 |
+
train_data = data.iloc[:train_end].copy()
|
| 212 |
+
val_data = data.iloc[train_end:val_end].copy()
|
| 213 |
+
test_data = data.iloc[val_end:].copy()
|
| 214 |
+
|
| 215 |
+
self.split_strategy = 'sliding_window'
|
| 216 |
+
|
| 217 |
+
return train_data, val_data, test_data
|
| 218 |
+
|
| 219 |
+
def _save_split_info(
|
| 220 |
+
self,
|
| 221 |
+
full_data: pd.DataFrame,
|
| 222 |
+
train_data: pd.DataFrame,
|
| 223 |
+
val_data: pd.DataFrame,
|
| 224 |
+
test_data: pd.DataFrame,
|
| 225 |
+
method: str
|
| 226 |
+
) -> None:
|
| 227 |
+
"""Save splitting information"""
|
| 228 |
+
n = len(full_data)
|
| 229 |
+
|
| 230 |
+
self.split_info = {
|
| 231 |
+
'method': method,
|
| 232 |
+
'strategy': self.split_strategy,
|
| 233 |
+
'train_size': len(train_data),
|
| 234 |
+
'val_size': len(val_data),
|
| 235 |
+
'test_size': len(test_data),
|
| 236 |
+
'train_percent': len(train_data) / n * 100,
|
| 237 |
+
'val_percent': len(val_data) / n * 100,
|
| 238 |
+
'test_percent': len(test_data) / n * 100,
|
| 239 |
+
'total_samples': n,
|
| 240 |
+
'features_count': len(full_data.columns),
|
| 241 |
+
'split_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
# Add temporal period information if available
|
| 245 |
+
if isinstance(full_data.index, pd.DatetimeIndex):
|
| 246 |
+
self.split_info.update({
|
| 247 |
+
'train_period': {
|
| 248 |
+
'start': train_data.index.min().strftime('%Y-%m-%d'),
|
| 249 |
+
'end': train_data.index.max().strftime('%Y-%m-%d')
|
| 250 |
+
},
|
| 251 |
+
'val_period': {
|
| 252 |
+
'start': val_data.index.min().strftime('%Y-%m-%d'),
|
| 253 |
+
'end': val_data.index.max().strftime('%Y-%m-%d')
|
| 254 |
+
},
|
| 255 |
+
'test_period': {
|
| 256 |
+
'start': test_data.index.min().strftime('%Y-%m-%d'),
|
| 257 |
+
'end': test_data.index.max().strftime('%Y-%m-%d')
|
| 258 |
+
}
|
| 259 |
+
})
|
| 260 |
+
|
| 261 |
+
# Save split indices
|
| 262 |
+
self.split_indices = {
|
| 263 |
+
'train': train_data.index.tolist(),
|
| 264 |
+
'val': val_data.index.tolist(),
|
| 265 |
+
'test': test_data.index.tolist()
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
def _log_split_summary(
|
| 269 |
+
self,
|
| 270 |
+
train_data: pd.DataFrame,
|
| 271 |
+
val_data: pd.DataFrame,
|
| 272 |
+
test_data: pd.DataFrame
|
| 273 |
+
) -> None:
|
| 274 |
+
"""Log splitting summary"""
|
| 275 |
+
logger.info("✓ Data split completed:")
|
| 276 |
+
logger.info(f" Train: {len(train_data)} records ({self.split_info['train_percent']:.1f}%)")
|
| 277 |
+
logger.info(f" Validation: {len(val_data)} records ({self.split_info['val_percent']:.1f}%)")
|
| 278 |
+
logger.info(f" Test: {len(test_data)} records ({self.split_info['test_percent']:.1f}%)")
|
| 279 |
+
|
| 280 |
+
if 'train_period' in self.split_info:
|
| 281 |
+
logger.info(f"\nPeriods:")
|
| 282 |
+
logger.info(f" Train: {self.split_info['train_period']['start']} - {self.split_info['train_period']['end']}")
|
| 283 |
+
logger.info(f" Validation: {self.split_info['val_period']['start']} - {self.split_info['val_period']['end']}")
|
| 284 |
+
logger.info(f" Test: {self.split_info['test_period']['start']} - {self.split_info['test_period']['end']}")
|
| 285 |
+
|
| 286 |
+
# Target variable statistics
|
| 287 |
+
target = self.config.target_column
|
| 288 |
+
if target in train_data.columns:
|
| 289 |
+
logger.info(f"\nTarget variable '{target}' statistics:")
|
| 290 |
+
logger.info(f" Train: mean={train_data[target].mean():.2f}, std={train_data[target].std():.2f}")
|
| 291 |
+
logger.info(f" Validation: mean={val_data[target].mean():.2f}, std={val_data[target].std():.2f}")
|
| 292 |
+
logger.info(f" Test: mean={test_data[target].mean():.2f}, std={test_data[target].std():.2f}")
|
| 293 |
+
|
| 294 |
+
def _plot_data_split(
|
| 295 |
+
self,
|
| 296 |
+
full_data: pd.DataFrame,
|
| 297 |
+
train_data: pd.DataFrame,
|
| 298 |
+
val_data: pd.DataFrame,
|
| 299 |
+
test_data: pd.DataFrame
|
| 300 |
+
) -> None:
|
| 301 |
+
"""Visualise data splitting"""
|
| 302 |
+
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
| 303 |
+
|
| 304 |
+
target = self.config.target_column
|
| 305 |
+
|
| 306 |
+
# 1. Time series with set highlighting
|
| 307 |
+
if target in full_data.columns and isinstance(full_data.index, pd.DatetimeIndex):
|
| 308 |
+
axes[0, 0].plot(train_data.index, train_data[target],
|
| 309 |
+
label='Train', color='blue', alpha=0.7, linewidth=1)
|
| 310 |
+
axes[0, 0].plot(val_data.index, val_data[target],
|
| 311 |
+
label='Validation', color='orange', alpha=0.7, linewidth=1)
|
| 312 |
+
axes[0, 0].plot(test_data.index, test_data[target],
|
| 313 |
+
label='Test', color='red', alpha=0.7, linewidth=1)
|
| 314 |
+
|
| 315 |
+
axes[0, 0].set_title(f'Data Split: {target}')
|
| 316 |
+
axes[0, 0].set_xlabel('Date')
|
| 317 |
+
axes[0, 0].set_ylabel(target)
|
| 318 |
+
axes[0, 0].legend()
|
| 319 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 320 |
+
|
| 321 |
+
# 2. Yearly distribution
|
| 322 |
+
if isinstance(full_data.index, pd.DatetimeIndex):
|
| 323 |
+
full_data['year'] = full_data.index.year
|
| 324 |
+
train_data['year'] = train_data.index.year
|
| 325 |
+
val_data['year'] = val_data.index.year
|
| 326 |
+
test_data['year'] = test_data.index.year
|
| 327 |
+
|
| 328 |
+
years = sorted(full_data['year'].unique())
|
| 329 |
+
train_counts = [len(train_data[train_data['year'] == year]) for year in years]
|
| 330 |
+
val_counts = [len(val_data[val_data['year'] == year]) for year in years]
|
| 331 |
+
test_counts = [len(test_data[test_data['year'] == year]) for year in years]
|
| 332 |
+
|
| 333 |
+
x = np.arange(len(years))
|
| 334 |
+
width = 0.25
|
| 335 |
+
|
| 336 |
+
axes[0, 1].bar(x - width, train_counts, width, label='Train', color='blue', alpha=0.7)
|
| 337 |
+
axes[0, 1].bar(x, val_counts, width, label='Validation', color='orange', alpha=0.7)
|
| 338 |
+
axes[0, 1].bar(x + width, test_counts, width, label='Test', color='red', alpha=0.7)
|
| 339 |
+
|
| 340 |
+
axes[0, 1].set_title('Yearly Data Distribution')
|
| 341 |
+
axes[0, 1].set_xlabel('Year')
|
| 342 |
+
axes[0, 1].set_ylabel('Number of Records')
|
| 343 |
+
axes[0, 1].set_xticks(x)
|
| 344 |
+
axes[0, 1].set_xticklabels(years, rotation=45)
|
| 345 |
+
axes[0, 1].legend()
|
| 346 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 347 |
+
|
| 348 |
+
# Remove added columns
|
| 349 |
+
for df in [full_data, train_data, val_data, test_data]:
|
| 350 |
+
if 'year' in df.columns:
|
| 351 |
+
df.drop('year', axis=1, inplace=True)
|
| 352 |
+
|
| 353 |
+
# 3. Target variable distribution
|
| 354 |
+
if target in full_data.columns:
|
| 355 |
+
axes[1, 0].hist(train_data[target].dropna(), bins=30, alpha=0.5, label='Train', density=True)
|
| 356 |
+
axes[1, 0].hist(val_data[target].dropna(), bins=30, alpha=0.5, label='Validation', density=True)
|
| 357 |
+
axes[1, 0].hist(test_data[target].dropna(), bins=30, alpha=0.5, label='Test', density=True)
|
| 358 |
+
|
| 359 |
+
axes[1, 0].set_title(f'{target} Distribution Across Sets')
|
| 360 |
+
axes[1, 0].set_xlabel(target)
|
| 361 |
+
axes[1, 0].set_ylabel('Density')
|
| 362 |
+
axes[1, 0].legend()
|
| 363 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 364 |
+
|
| 365 |
+
# 4. Set statistics
|
| 366 |
+
if target in full_data.columns:
|
| 367 |
+
stats_data = []
|
| 368 |
+
for name, df in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
|
| 369 |
+
if target in df.columns:
|
| 370 |
+
stats_data.append({
|
| 371 |
+
'Dataset': name,
|
| 372 |
+
'Mean': df[target].mean(),
|
| 373 |
+
'Std': df[target].std(),
|
| 374 |
+
'Min': df[target].min(),
|
| 375 |
+
'Max': df[target].max()
|
| 376 |
+
})
|
| 377 |
+
|
| 378 |
+
if stats_data:
|
| 379 |
+
stats_df = pd.DataFrame(stats_data)
|
| 380 |
+
stats_table = axes[1, 1].table(
|
| 381 |
+
cellText=stats_df.round(2).values,
|
| 382 |
+
colLabels=stats_df.columns,
|
| 383 |
+
cellLoc='center',
|
| 384 |
+
loc='center'
|
| 385 |
+
)
|
| 386 |
+
stats_table.auto_set_font_size(False)
|
| 387 |
+
stats_table.set_fontsize(9)
|
| 388 |
+
stats_table.scale(1, 1.5)
|
| 389 |
+
axes[1, 1].axis('off')
|
| 390 |
+
axes[1, 1].set_title('Set Statistics')
|
| 391 |
+
|
| 392 |
+
plt.suptitle(f'Data Splitting: {self.split_info["method"]} method', fontsize=14)
|
| 393 |
+
plt.tight_layout()
|
| 394 |
+
plt.savefig(
|
| 395 |
+
f'{self.config.results_dir}/plots/data_split.png',
|
| 396 |
+
dpi=300,
|
| 397 |
+
bbox_inches='tight'
|
| 398 |
+
)
|
| 399 |
+
plt.show()
|
| 400 |
+
|
| 401 |
+
def get_report(self) -> Dict:
|
| 402 |
+
"""Get data splitting report"""
|
| 403 |
+
return self.split_info
|
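A minimal usage sketch for DataSplitter along the same lines, assuming a Config whose target_column matches the frame below and whose save_plots flag is off; the synthetic daily series is purely illustrative.

import numpy as np
import pandas as pd

from config.config import Config
from splitting.data_splitter import DataSplitter

# Hypothetical daily series indexed by date.
idx = pd.date_range("2020-01-01", periods=1000, freq="D")
df = pd.DataFrame(
    {"demand": np.random.default_rng(0).normal(100.0, 10.0, len(idx))},
    index=idx,
)

splitter = DataSplitter(Config())  # Config() assumed constructible with defaults
train, val, test = splitter.split(df, test_size=0.2, validation_size=0.1, method="time")

report = splitter.get_report()
print(report["train_size"], report["val_size"], report["test_size"])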
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
stationarity/__init__.py
ADDED
|
File without changes
|
stationarity/stationarity_checker.py
ADDED
|
@@ -0,0 +1,631 @@
|
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 6: STATIONARITY ANALYSIS
|
| 3 |
+
# ============================================
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
import logging

logger = logging.getLogger(__name__)
|
| 6 |
+
from config.config import Config
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from statsmodels.tsa.stattools import adfuller, kpss, acf, pacf
|
| 11 |
+
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
|
| 12 |
+
|
| 13 |
+
class StationarityChecker:
|
| 14 |
+
"""Class for checking time series stationarity"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, config: Config):
|
| 17 |
+
"""
|
| 18 |
+
Initialise stationarity checker
|
| 19 |
+
|
| 20 |
+
Parameters:
|
| 21 |
+
-----------
|
| 22 |
+
config : Config
|
| 23 |
+
Experiment configuration
|
| 24 |
+
"""
|
| 25 |
+
self.config = config
|
| 26 |
+
self.test_results = {}
|
| 27 |
+
self.transformed_series = {}
|
| 28 |
+
self.best_transformation = {}
|
| 29 |
+
|
| 30 |
+
def check(
|
| 31 |
+
self,
|
| 32 |
+
data: pd.DataFrame,
|
| 33 |
+
target_col: Optional[str] = None,
|
| 34 |
+
make_stationary: bool = True,
|
| 35 |
+
try_transformations: bool = True
|
| 36 |
+
) -> Dict:
|
| 37 |
+
"""
|
| 38 |
+
Check time series stationarity
|
| 39 |
+
|
| 40 |
+
Parameters:
|
| 41 |
+
-----------
|
| 42 |
+
data : pd.DataFrame
|
| 43 |
+
Input data
|
| 44 |
+
target_col : str, optional
|
| 45 |
+
Target variable. If None, uses configuration default.
|
| 46 |
+
make_stationary : bool
|
| 47 |
+
Transform series to stationary form
|
| 48 |
+
try_transformations : bool
|
| 49 |
+
Try various transformations to achieve stationarity
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
--------
|
| 53 |
+
Dict
|
| 54 |
+
Stationarity test results
|
| 55 |
+
"""
|
| 56 |
+
logger.info("\n" + "="*80)
|
| 57 |
+
logger.info("STATIONARITY ANALYSIS")
|
| 58 |
+
logger.info("="*80)
|
| 59 |
+
|
| 60 |
+
target_col = target_col or self.config.target_column
|
| 61 |
+
|
| 62 |
+
if target_col not in data.columns:
|
| 63 |
+
logger.error(f"Target variable '{target_col}' not found")
|
| 64 |
+
return {}
|
| 65 |
+
|
| 66 |
+
series = data[target_col].dropna()
|
| 67 |
+
|
| 68 |
+
if len(series) < 10:
|
| 69 |
+
logger.warning("Insufficient data for stationarity analysis")
|
| 70 |
+
return {}
|
| 71 |
+
|
| 72 |
+
# Perform analysis
|
| 73 |
+
results = self._perform_stationarity_tests(series, target_col)
|
| 74 |
+
|
| 75 |
+
# Save results
|
| 76 |
+
self.test_results[target_col] = results
|
| 77 |
+
|
| 78 |
+
# Visualisation
|
| 79 |
+
if self.config.save_plots:
|
| 80 |
+
self._plot_stationarity_analysis(data, target_col, results)
|
| 81 |
+
|
| 82 |
+
# Log results
|
| 83 |
+
self._log_test_results(target_col, results)
|
| 84 |
+
|
| 85 |
+
# Transform to stationary form
|
| 86 |
+
if make_stationary and not results['overall']['is_stationary']:
|
| 87 |
+
if try_transformations:
|
| 88 |
+
transformed_data = self._make_stationary(data, target_col, results)
|
| 89 |
+
if transformed_data is not None:
|
| 90 |
+
data = transformed_data
|
| 91 |
+
|
| 92 |
+
return results
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _perform_stationarity_tests(
|
| 96 |
+
self,
|
| 97 |
+
series: pd.Series,
|
| 98 |
+
target_col: str
|
| 99 |
+
) -> Dict:
|
| 100 |
+
"""Perform various stationarity tests"""
|
| 101 |
+
results = {
|
| 102 |
+
'adf': self._adf_test(series),
|
| 103 |
+
'kpss': self._kpss_test(series),
|
| 104 |
+
'pp': self._pp_test(series),
|
| 105 |
+
'hurst': self._hurst_exponent(series),
|
| 106 |
+
'variance_ratio': self._variance_ratio_test(series),
|
| 107 |
+
'overall': {}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
# Determine overall stationarity
|
| 111 |
+
adf_stationary = results['adf'].get('is_stationary', False)
|
| 112 |
+
kpss_stationary = results['kpss'].get('is_stationary', False)
|
| 113 |
+
pp_stationary = results['pp'].get('is_stationary', False)
|
| 114 |
+
|
| 115 |
+
# Stationarity determination logic
|
| 116 |
+
if adf_stationary and kpss_stationary:
|
| 117 |
+
overall_stationary = True
|
| 118 |
+
confidence = 'high'
|
| 119 |
+
elif adf_stationary and not kpss_stationary:
|
| 120 |
+
overall_stationary = True # Tests disagree; favour ADF's rejection of the unit root
|
| 121 |
+
confidence = 'medium'
|
| 122 |
+
elif not adf_stationary and kpss_stationary:
|
| 123 |
+
overall_stationary = False # ADF fails to reject the unit root; treat as non-stationary
|
| 124 |
+
confidence = 'medium'
|
| 125 |
+
else:
|
| 126 |
+
overall_stationary = False
|
| 127 |
+
confidence = 'high'
|
| 128 |
+
|
| 129 |
+
results['overall'] = {
|
| 130 |
+
'is_stationary': overall_stationary,
|
| 131 |
+
'confidence': confidence,
|
| 132 |
+
'recommendation': self._get_stationarity_recommendation(results)
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
return results
|
| 136 |
+
|
| 137 |
+
def _adf_test(self, series: pd.Series) -> Dict:
|
| 138 |
+
"""Augmented Dickey-Fuller (ADF) test"""
|
| 139 |
+
try:
|
| 140 |
+
adf_result = adfuller(series, autolag='AIC')
|
| 141 |
+
|
| 142 |
+
return {
|
| 143 |
+
'statistic': float(adf_result[0]),
|
| 144 |
+
'pvalue': float(adf_result[1]),
|
| 145 |
+
'critical_values': {k: float(v) for k, v in adf_result[4].items()},
|
| 146 |
+
'is_stationary': adf_result[1] < 0.05,
|
| 147 |
+
'used_lag': int(adf_result[2]),
|
| 148 |
+
'nobs': int(adf_result[3])
|
| 149 |
+
}
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logger.warning(f"ADF test failed: {e}")
|
| 152 |
+
return {
|
| 153 |
+
'statistic': np.nan,
|
| 154 |
+
'pvalue': np.nan,
|
| 155 |
+
'critical_values': {},
|
| 156 |
+
'is_stationary': False,
|
| 157 |
+
'error': str(e)
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
def _kpss_test(self, series: pd.Series) -> Dict:
|
| 161 |
+
"""KPSS test"""
|
| 162 |
+
try:
|
| 163 |
+
kpss_result = kpss(series, regression='c', nlags='auto')
|
| 164 |
+
|
| 165 |
+
return {
|
| 166 |
+
'statistic': float(kpss_result[0]),
|
| 167 |
+
'pvalue': float(kpss_result[1]),
|
| 168 |
+
'critical_values': {k: float(v) for k, v in kpss_result[3].items()},
|
| 169 |
+
'is_stationary': kpss_result[1] > 0.05, # KPSS: p > 0.05 indicates stationarity
|
| 170 |
+
'used_lag': int(kpss_result[2])
|
| 171 |
+
}
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logger.warning(f"KPSS test failed: {e}")
|
| 174 |
+
return {
|
| 175 |
+
'statistic': np.nan,
|
| 176 |
+
'pvalue': np.nan,
|
| 177 |
+
'critical_values': {},
|
| 178 |
+
'is_stationary': False,
|
| 179 |
+
'error': str(e)
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
def _pp_test(self, series: pd.Series) -> Dict:
|
| 183 |
+
"""Phillips-Perron test"""
|
| 184 |
+
try:
|
| 185 |
+
# Phillips-Perron test (provided by the arch package)
|
| 186 |
+
from arch.unitroot import PhillipsPerron
|
| 187 |
+
|
| 188 |
+
pp_result = PhillipsPerron(series)
|
| 189 |
+
|
| 190 |
+
return {
|
| 191 |
+
'statistic': float(pp_result.stat),
|
| 192 |
+
'pvalue': float(pp_result.pvalue),
|
| 193 |
+
'critical_values': pp_result.critical_values,
|
| 194 |
+
'is_stationary': pp_result.pvalue < 0.05
|
| 195 |
+
}
|
| 196 |
+
except Exception:
|
| 197 |
+
# If the arch package (and its Phillips-Perron test) is not available
|
| 198 |
+
return {
|
| 199 |
+
'statistic': np.nan,
|
| 200 |
+
'pvalue': np.nan,
|
| 201 |
+
'critical_values': {},
|
| 202 |
+
'is_stationary': False,
|
| 203 |
+
'note': 'Phillips-Perron test not available'
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
def _hurst_exponent(self, series: pd.Series) -> Dict:
|
| 207 |
+
"""Calculate Hurst exponent"""
|
| 208 |
+
try:
|
| 209 |
+
# Simplified Hurst exponent calculation
|
| 210 |
+
lags = range(2, min(100, len(series)//4))
|
| 211 |
+
tau = []
|
| 212 |
+
|
| 213 |
+
for lag in lags:
|
| 214 |
+
# Split series into subsequences of length lag
|
| 215 |
+
n = len(series) // lag
|
| 216 |
+
if n < 2:
|
| 217 |
+
continue
|
| 218 |
+
|
| 219 |
+
subseries = [series[i*lag:(i+1)*lag] for i in range(n)]
|
| 220 |
+
# Calculate R/S for each subsequence
|
| 221 |
+
rs_values = []
|
| 222 |
+
for sub in subseries:
|
| 223 |
+
if len(sub) > 1:
|
| 224 |
+
mean = np.mean(sub)
|
| 225 |
+
deviations = sub - mean
|
| 226 |
+
z = np.cumsum(deviations)
|
| 227 |
+
r = np.max(z) - np.min(z)
|
| 228 |
+
s = np.std(sub)
|
| 229 |
+
if s > 0:
|
| 230 |
+
rs_values.append(r / s)
|
| 231 |
+
|
| 232 |
+
if rs_values:
|
| 233 |
+
tau.append(np.mean(rs_values))
|
| 234 |
+
|
| 235 |
+
if len(tau) > 2:
|
| 236 |
+
# Linear regression in log coordinates
|
| 237 |
+
x = np.log(lags[:len(tau)])
|
| 238 |
+
y = np.log(tau)
|
| 239 |
+
|
| 240 |
+
if len(x) > 1 and len(y) > 1:
|
| 241 |
+
slope = np.polyfit(x, y, 1)[0]
|
| 242 |
+
|
| 243 |
+
# Hurst exponent interpretation
|
| 244 |
+
if slope > 0.5:
|
| 245 |
+
trend_type = 'persistent'
|
| 246 |
+
elif slope < 0.5:
|
| 247 |
+
trend_type = 'anti-persistent'
|
| 248 |
+
else:
|
| 249 |
+
trend_type = 'random'
|
| 250 |
+
|
| 251 |
+
return {
|
| 252 |
+
'exponent': float(slope),
|
| 253 |
+
'trend_type': trend_type,
|
| 254 |
+
'interpretation': self._interpret_hurst(slope)
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
return {
|
| 258 |
+
'exponent': np.nan,
|
| 259 |
+
'trend_type': 'unknown',
|
| 260 |
+
'interpretation': 'Insufficient data'
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
except Exception as e:
|
| 264 |
+
logger.debug(f"Hurst exponent not calculated: {e}")
|
| 265 |
+
return {
|
| 266 |
+
'exponent': np.nan,
|
| 267 |
+
'trend_type': 'unknown',
|
| 268 |
+
'error': str(e)
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
def _interpret_hurst(self, hurst_exponent: float) -> str:
|
| 272 |
+
"""Interpret Hurst exponent"""
|
| 273 |
+
if hurst_exponent > 0.75:
|
| 274 |
+
return "Strong persistence (long-term memory)"
|
| 275 |
+
elif hurst_exponent > 0.6:
|
| 276 |
+
return "Moderate persistence"
|
| 277 |
+
elif hurst_exponent > 0.4:
|
| 278 |
+
return "Weak persistence / random walk"
|
| 279 |
+
elif hurst_exponent > 0.25:
|
| 280 |
+
return "Weak anti-persistence"
|
| 281 |
+
else:
|
| 282 |
+
return "Strong anti-persistence (frequent trend reversal)"
|
| 283 |
+
|
| 284 |
+
def _variance_ratio_test(self, series: pd.Series) -> Dict:
|
| 285 |
+
"""Variance Ratio test for random walk"""
|
| 286 |
+
try:
|
| 287 |
+
# Simplified variance ratio test
|
| 288 |
+
if len(series) < 20:
|
| 289 |
+
return {'ratio': np.nan, 'is_random_walk': False}
|
| 290 |
+
|
| 291 |
+
# Calculate differences
|
| 292 |
+
diff1 = series.diff(1).dropna()
|
| 293 |
+
diff2 = series.diff(2).dropna()[1:] # Shift to align indices
|
| 294 |
+
|
| 295 |
+
if len(diff1) < 5 or len(diff2) < 5:
|
| 296 |
+
return {'ratio': np.nan, 'is_random_walk': False}
|
| 297 |
+
|
| 298 |
+
var1 = np.var(diff1)
|
| 299 |
+
var2 = np.var(diff2)
|
| 300 |
+
|
| 301 |
+
if var1 > 0:
|
| 302 |
+
ratio = var2 / (2 * var1)
|
| 303 |
+
|
| 304 |
+
# For random walk ratio ≈ 1
|
| 305 |
+
is_random_walk = 0.8 < ratio < 1.2
|
| 306 |
+
|
| 307 |
+
return {
|
| 308 |
+
'ratio': float(ratio),
|
| 309 |
+
'is_random_walk': bool(is_random_walk),
|
| 310 |
+
'var_diff1': float(var1),
|
| 311 |
+
'var_diff2': float(var2)
|
| 312 |
+
}
|
| 313 |
+
else:
|
| 314 |
+
return {'ratio': np.nan, 'is_random_walk': False}
|
| 315 |
+
|
| 316 |
+
except Exception as e:
|
| 317 |
+
logger.debug(f"Variance ratio test failed: {e}")
|
| 318 |
+
return {'ratio': np.nan, 'is_random_walk': False, 'error': str(e)}
|
| 319 |
+
|
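As a sanity check on the ratio above (variance of 2-step differences over twice the variance of 1-step differences): a pure random walk gives a value near 1, while uncorrelated noise gives about 0.5, which is why the 0.8-1.2 band is read as "random walk". A minimal sketch using only numpy and pandas:

```python
import numpy as np
import pandas as pd

def variance_ratio(series: pd.Series) -> float:
    # VR(2) = Var(x_t - x_{t-2}) / (2 * Var(x_t - x_{t-1})); ~1 for a random walk
    d1 = series.diff(1).dropna()
    d2 = series.diff(2).dropna()
    return float(np.var(d2) / (2 * np.var(d1)))

rng = np.random.default_rng(1)
print(variance_ratio(pd.Series(np.cumsum(rng.standard_normal(5000)))))  # random walk -> ~1.0
print(variance_ratio(pd.Series(rng.standard_normal(5000))))             # white noise  -> ~0.5
```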
| 320 |
+
def _get_stationarity_recommendation(self, results: Dict) -> str:
|
| 321 |
+
"""Get stationarity recommendations"""
|
| 322 |
+
# Check for keys before access
|
| 323 |
+
if 'overall' not in results or 'is_stationary' not in results['overall']:
|
| 324 |
+
return "Could not determine stationarity. Check data and test settings."
|
| 325 |
+
|
| 326 |
+
if results['overall']['is_stationary']:
|
| 327 |
+
return "Series is stationary, suitable for modelling"
|
| 328 |
+
else:
|
| 329 |
+
recommendations = []
|
| 330 |
+
|
| 331 |
+
# Check Hurst test results
|
| 332 |
+
if 'hurst' in results and 'exponent' in results['hurst']:
|
| 333 |
+
hurst_exponent = results['hurst']['exponent']
|
| 334 |
+
if not np.isnan(hurst_exponent) and hurst_exponent > 0.6:
|
| 335 |
+
recommendations.append("Apply differencing to remove trend")
|
| 336 |
+
|
| 337 |
+
# Check ADF test
|
| 338 |
+
if 'adf' in results and 'pvalue' in results['adf']:
|
| 339 |
+
adf_pvalue = results['adf']['pvalue']
|
| 340 |
+
if not np.isnan(adf_pvalue) and adf_pvalue > 0.1:
|
| 341 |
+
recommendations.append("Consider seasonal differencing due to non-stationarity")
|
| 342 |
+
|
| 343 |
+
if len(recommendations) == 0:
|
| 344 |
+
recommendations.append("Try logarithmic transformation and differencing")
|
| 345 |
+
|
| 346 |
+
return "; ".join(recommendations)
|
| 347 |
+
|
| 348 |
+
def _plot_stationarity_analysis(
|
| 349 |
+
self,
|
| 350 |
+
data: pd.DataFrame,
|
| 351 |
+
target_col: str,
|
| 352 |
+
results: Dict
|
| 353 |
+
) -> None:
|
| 354 |
+
"""Visualise stationarity analysis"""
|
| 355 |
+
series = data[target_col]
|
| 356 |
+
|
| 357 |
+
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
|
| 358 |
+
|
| 359 |
+
# 1. Original series
|
| 360 |
+
axes[0, 0].plot(series.index, series, linewidth=1)
|
| 361 |
+
axes[0, 0].set_title(f'Original Time Series: {target_col}')
|
| 362 |
+
axes[0, 0].set_xlabel('Date')
|
| 363 |
+
axes[0, 0].set_ylabel(target_col)
|
| 364 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 365 |
+
|
| 366 |
+
# 2. Rolling statistics
|
| 367 |
+
rolling_mean = series.rolling(window=365, center=True, min_periods=1).mean()
|
| 368 |
+
rolling_std = series.rolling(window=365, center=True, min_periods=1).std()
|
| 369 |
+
|
| 370 |
+
axes[0, 1].plot(series.index, series, label='Original series', alpha=0.7, linewidth=0.5)
|
| 371 |
+
axes[0, 1].plot(rolling_mean.index, rolling_mean, label='Rolling mean (365)', color='red', linewidth=2)
|
| 372 |
+
axes[0, 1].plot(rolling_std.index, rolling_std, label='Rolling STD (365)', color='green', linewidth=2)
|
| 373 |
+
axes[0, 1].set_title(f'Rolling Statistics: {target_col}')
|
| 374 |
+
axes[0, 1].set_xlabel('Date')
|
| 375 |
+
axes[0, 1].set_ylabel(target_col)
|
| 376 |
+
axes[0, 1].legend(fontsize=8)
|
| 377 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 378 |
+
|
| 379 |
+
# 3. ACF
|
| 380 |
+
plot_acf(series.dropna(), lags=50, ax=axes[0, 2], alpha=0.05)
|
| 381 |
+
axes[0, 2].set_title(f'Autocorrelation Function (ACF): {target_col}')
|
| 382 |
+
axes[0, 2].set_xlabel('Lag')
|
| 383 |
+
axes[0, 2].set_ylabel('Autocorrelation')
|
| 384 |
+
axes[0, 2].grid(True, alpha=0.3)
|
| 385 |
+
|
| 386 |
+
# 4. PACF
|
| 387 |
+
plot_pacf(series.dropna(), lags=50, ax=axes[1, 0], alpha=0.05)
|
| 388 |
+
axes[1, 0].set_title(f'Partial Autocorrelation Function (PACF): {target_col}')
|
| 389 |
+
axes[1, 0].set_xlabel('Lag')
|
| 390 |
+
axes[1, 0].set_ylabel('Partial Autocorrelation')
|
| 391 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 392 |
+
|
| 393 |
+
# 5. Histogram and Q-Q plot
|
| 394 |
+
axes[1, 1].hist(series.dropna(), bins=30, edgecolor='black', alpha=0.7, density=True)
|
| 395 |
+
axes[1, 1].set_title(f'Distribution: {target_col}')
|
| 396 |
+
axes[1, 1].set_xlabel('Value')
|
| 397 |
+
axes[1, 1].set_ylabel('Density')
|
| 398 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 399 |
+
|
| 400 |
+
# 6. Series differences
|
| 401 |
+
diff1 = series.diff(1).dropna()
|
| 402 |
+
axes[1, 2].plot(diff1.index, diff1, linewidth=0.5)
|
| 403 |
+
axes[1, 2].set_title(f'First Difference: {target_col}')
|
| 404 |
+
axes[1, 2].set_xlabel('Date')
|
| 405 |
+
axes[1, 2].set_ylabel(f'Δ{target_col}')
|
| 406 |
+
axes[1, 2].grid(True, alpha=0.3)
|
| 407 |
+
|
| 408 |
+
plt.suptitle(
|
| 409 |
+
f'Stationarity Analysis: {target_col}\n'
|
| 410 |
+
f'Stationary: {"✓ Yes" if results["overall"]["is_stationary"] else "✗ No"} '
|
| 411 |
+
f'(confidence: {results["overall"]["confidence"]})',
|
| 412 |
+
fontsize=14
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
plt.tight_layout()
|
| 416 |
+
plt.savefig(
|
| 417 |
+
f'{self.config.results_dir}/plots/stationarity_{target_col}.png',
|
| 418 |
+
dpi=300,
|
| 419 |
+
bbox_inches='tight'
|
| 420 |
+
)
|
| 421 |
+
plt.show()
|
| 422 |
+
|
| 423 |
+
def _log_test_results(self, target_col: str, results: Dict) -> None:
|
| 424 |
+
"""Log test results"""
|
| 425 |
+
logger.info("\nSTATIONARITY TEST RESULTS:")
|
| 426 |
+
logger.info("-" * 50)
|
| 427 |
+
|
| 428 |
+
# ADF test
|
| 429 |
+
adf = results['adf']
|
| 430 |
+
logger.info(f"Augmented Dickey-Fuller (ADF) test:")
|
| 431 |
+
logger.info(f" Statistic: {adf['statistic']:.4f}")
|
| 432 |
+
logger.info(f" p-value: {adf['pvalue']:.4f}")
|
| 433 |
+
logger.info(f" Stationary: {'✓ Yes' if adf['is_stationary'] else '✗ No'}")
|
| 434 |
+
|
| 435 |
+
# KPSS test
|
| 436 |
+
kpss_test = results['kpss']
|
| 437 |
+
if 'statistic' in kpss_test and not np.isnan(kpss_test['statistic']):
|
| 438 |
+
logger.info(f"\nKPSS test:")
|
| 439 |
+
logger.info(f" Statistic: {kpss_test['statistic']:.4f}")
|
| 440 |
+
logger.info(f" p-value: {kpss_test['pvalue']:.4f}")
|
| 441 |
+
logger.info(f" Stationary: {'✓ Yes' if kpss_test['is_stationary'] else '✗ No'}")
|
| 442 |
+
|
| 443 |
+
# Hurst exponent
|
| 444 |
+
hurst = results['hurst']
|
| 445 |
+
if 'exponent' in hurst and not np.isnan(hurst['exponent']):
|
| 446 |
+
logger.info(f"\nHurst exponent:")
|
| 447 |
+
logger.info(f" Value: {hurst['exponent']:.3f}")
|
| 448 |
+
logger.info(f" Trend type: {hurst['trend_type']}")
|
| 449 |
+
logger.info(f" Interpretation: {hurst.get('interpretation', '')}")
|
| 450 |
+
|
| 451 |
+
# Overall interpretation
|
| 452 |
+
logger.info(f"\nOVERALL CONCLUSION:")
|
| 453 |
+
logger.info("-" * 30)
|
| 454 |
+
logger.info(f"Stationary: {'✓ Yes' if results['overall']['is_stationary'] else '✗ No'}")
|
| 455 |
+
logger.info(f"Confidence: {results['overall']['confidence']}")
|
| 456 |
+
logger.info(f"Recommendation: {results['overall']['recommendation']}")
|
| 457 |
+
|
| 458 |
+
def _make_stationary(
|
| 459 |
+
self,
|
| 460 |
+
data: pd.DataFrame,
|
| 461 |
+
target_col: str,
|
| 462 |
+
results: Dict
|
| 463 |
+
) -> Optional[pd.DataFrame]:
|
| 464 |
+
"""
|
| 465 |
+
Transform series to stationary form
|
| 466 |
+
|
| 467 |
+
Parameters:
|
| 468 |
+
-----------
|
| 469 |
+
data : pd.DataFrame
|
| 470 |
+
Input data
|
| 471 |
+
target_col : str
|
| 472 |
+
Target variable
|
| 473 |
+
results : Dict
|
| 474 |
+
Stationarity test results
|
| 475 |
+
|
| 476 |
+
Returns:
|
| 477 |
+
--------
|
| 478 |
+
Optional[pd.DataFrame]
|
| 479 |
+
Data with stationary series or None if transformation failed
|
| 480 |
+
"""
|
| 481 |
+
logger.info("\nTRANSFORMING TO STATIONARY FORM:")
|
| 482 |
+
logger.info("-" * 40)
|
| 483 |
+
|
| 484 |
+
data_processed = data.copy()
|
| 485 |
+
series = data_processed[target_col]
|
| 486 |
+
|
| 487 |
+
# Stationarisation methods in order of preference
|
| 488 |
+
methods = [
|
| 489 |
+
('diff', 'first-order differencing'),
|
| 490 |
+
('seasonal_diff', f'seasonal differencing (period={self.config.seasonal_period})'),
|
| 491 |
+
('log_diff', 'logarithmic differencing'),
|
| 492 |
+
('boxcox_diff', 'Box-Cox + differencing'),
|
| 493 |
+
('detrend', 'detrending'),
|
| 494 |
+
('combination', 'combined method')
|
| 495 |
+
]
|
| 496 |
+
|
| 497 |
+
best_method = None
|
| 498 |
+
best_series = None
|
| 499 |
+
best_pvalue = 1.0
|
| 500 |
+
best_stationary = False
|
| 501 |
+
|
| 502 |
+
for method, method_name in methods:
|
| 503 |
+
try:
|
| 504 |
+
if method == 'diff':
|
| 505 |
+
# Simple differencing
|
| 506 |
+
transformed = series.diff(1).dropna()
|
| 507 |
+
test_series = transformed
|
| 508 |
+
|
| 509 |
+
elif method == 'seasonal_diff':
|
| 510 |
+
# Seasonal differencing
|
| 511 |
+
transformed = series.diff(self.config.seasonal_period).dropna()
|
| 512 |
+
test_series = transformed
|
| 513 |
+
|
| 514 |
+
elif method == 'log_diff':
|
| 515 |
+
# Logarithmic differencing
|
| 516 |
+
if (series > 0).all():
|
| 517 |
+
log_series = np.log(series)
|
| 518 |
+
transformed = log_series.diff(1).dropna()
|
| 519 |
+
test_series = transformed
|
| 520 |
+
else:
|
| 521 |
+
# Shift for negative values
|
| 522 |
+
shift = abs(series.min()) + 1 if series.min() <= 0 else 0
|
| 523 |
+
log_series = np.log(series + shift)
|
| 524 |
+
transformed = log_series.diff(1).dropna()
|
| 525 |
+
test_series = transformed
|
| 526 |
+
|
| 527 |
+
elif method == 'boxcox_diff':
|
| 528 |
+
# Box-Cox transformation + differencing
|
| 529 |
+
try:
|
| 530 |
+
from scipy.stats import boxcox
|
| 531 |
+
# Add constant for positive values
|
| 532 |
+
shift = abs(series.min()) + 1 if series.min() <= 0 else 0
|
| 533 |
+
boxcox_series, _ = boxcox(series + shift)
|
| 534 |
+
transformed = pd.Series(boxcox_series, index=series.index).diff(1).dropna()
|
| 535 |
+
test_series = transformed
|
| 536 |
+
except:
|
| 537 |
+
continue
|
| 538 |
+
|
| 539 |
+
elif method == 'detrend':
|
| 540 |
+
# Linear detrending
|
| 541 |
+
x = np.arange(len(series))
|
| 542 |
+
y = series.values
|
| 543 |
+
coeffs = np.polyfit(x, y, 1)
|
| 544 |
+
trend = np.polyval(coeffs, x)
|
| 545 |
+
transformed = pd.Series(y - trend, index=series.index)
|
| 546 |
+
test_series = transformed
|
| 547 |
+
|
| 548 |
+
elif method == 'combination':
|
| 549 |
+
# Combined method: log + differencing + detrending
|
| 550 |
+
if (series > 0).all():
|
| 551 |
+
log_series = np.log(series)
|
| 552 |
+
diff_series = log_series.diff(1)
|
| 553 |
+
|
| 554 |
+
# Detrending residuals
|
| 555 |
+
x = np.arange(len(diff_series))
|
| 556 |
+
y = diff_series.values
|
| 557 |
+
valid_mask = ~np.isnan(y)
|
| 558 |
+
|
| 559 |
+
if valid_mask.sum() > 2:
|
| 560 |
+
coeffs = np.polyfit(x[valid_mask], y[valid_mask], 1)
|
| 561 |
+
trend = np.polyval(coeffs, x)
|
| 562 |
+
transformed = pd.Series(y - trend, index=series.index)
|
| 563 |
+
test_series = transformed.dropna()
|
| 564 |
+
else:
|
| 565 |
+
test_series = diff_series.dropna()
|
| 566 |
+
else:
|
| 567 |
+
continue
|
| 568 |
+
|
| 569 |
+
# Check stationarity after transformation
|
| 570 |
+
if len(test_series) > 10:
|
| 571 |
+
adf_result = adfuller(test_series.dropna())
|
| 572 |
+
is_stationary = adf_result[1] < 0.05
|
| 573 |
+
pvalue = adf_result[1]
|
| 574 |
+
|
| 575 |
+
logger.info(f" Method: {method_name}")
|
| 576 |
+
logger.info(f" ADF p-value: {pvalue:.4f}")
|
| 577 |
+
logger.info(f" Stationary: {'✓ Yes' if is_stationary else '✗ No'}")
|
| 578 |
+
|
| 579 |
+
# Save best method
|
| 580 |
+
if is_stationary and pvalue < best_pvalue:
|
| 581 |
+
best_pvalue = pvalue
|
| 582 |
+
best_method = method
|
| 583 |
+
best_series = transformed
|
| 584 |
+
best_stationary = True
|
| 585 |
+
|
| 586 |
+
if pvalue < 0.01: # Very good result
|
| 587 |
+
break
|
| 588 |
+
|
| 589 |
+
except Exception as e:
|
| 590 |
+
logger.debug(f" Method {method} failed: {e}")
|
| 591 |
+
continue
|
| 592 |
+
|
| 593 |
+
# Save results
|
| 594 |
+
if best_series is not None:
|
| 595 |
+
new_col_name = f'{target_col}_stationary_{best_method}'
|
| 596 |
+
|
| 597 |
+
# Align indices
|
| 598 |
+
aligned_series = pd.Series(
|
| 599 |
+
best_series.values,
|
| 600 |
+
index=data_processed.index[-len(best_series):]
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
data_processed[new_col_name] = aligned_series
|
| 604 |
+
|
| 605 |
+
self.transformed_series[target_col] = {
|
| 606 |
+
'method': best_method,
|
| 607 |
+
'new_column': new_col_name,
|
| 608 |
+
'pvalue': float(best_pvalue),
|
| 609 |
+
'is_stationary': best_stationary,
|
| 610 |
+
'original_shape': len(series),
|
| 611 |
+
'transformed_shape': len(best_series)
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
self.best_transformation[target_col] = best_method
|
| 615 |
+
|
| 616 |
+
logger.info(f"\n✓ Selected method: {best_method}")
|
| 617 |
+
logger.info(f" Saved as '{new_col_name}'")
|
| 618 |
+
logger.info(f" p-value: {best_pvalue:.4f}")
|
| 619 |
+
|
| 620 |
+
return data_processed
|
| 621 |
+
else:
|
| 622 |
+
logger.warning("✗ Could not find suitable transformation for stationarisation")
|
| 623 |
+
return None
|
| 624 |
+
|
| 625 |
+
def get_report(self) -> Dict:
|
| 626 |
+
"""Get stationarity report"""
|
| 627 |
+
return {
|
| 628 |
+
'test_results': self.test_results,
|
| 629 |
+
'transformed_series': self.transformed_series,
|
| 630 |
+
'best_transformations': self.best_transformation
|
| 631 |
+
}
|
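A condensed, standalone version of the selection logic in `_make_stationary` above: try a few transforms and keep the one whose ADF p-value is lowest. It is a sketch, not the module's API; only `adfuller` (already used in the methods above) is relied on, and the candidate set is trimmed to three transforms.

```python
import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import adfuller

def best_stationary_transform(series: pd.Series, seasonal_period: int = 7):
    candidates = {
        "diff": series.diff(1),
        "seasonal_diff": series.diff(seasonal_period),
        "log_diff": np.log(series).diff(1) if (series > 0).all() else None,
    }
    best_name, best_pvalue = None, 1.0
    for name, transformed in candidates.items():
        if transformed is None:
            continue
        transformed = transformed.dropna()
        if len(transformed) <= 10:
            continue
        pvalue = adfuller(transformed)[1]   # same test as used in the methods above
        if pvalue < best_pvalue:
            best_name, best_pvalue = name, pvalue
    return best_name, best_pvalue

rng = np.random.default_rng(2)
trend = pd.Series(np.cumsum(rng.standard_normal(500)) + 0.1 * np.arange(500))
print(best_stationary_transform(trend))  # first or seasonal differencing usually wins here
```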
streamlit/streamlit_app.py
ADDED
The diff for this file is too large to render. See raw diff.

temp_data.csv
ADDED
The diff for this file is too large to render. See raw diff.

validation/__init__.py
ADDED
File without changes

validation/data_validator.py
ADDED
@@ -0,0 +1,655 @@
| 1 |
+
# ============================================
|
| 2 |
+
# CLASS 12: DATA VALIDATION
|
| 3 |
+
# ============================================
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
import logging
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
from config.config import Config
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
class DataValidator:
|
| 15 |
+
"""Class for data quality validation"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, config: Config):
|
| 18 |
+
"""
|
| 19 |
+
Initialise data validator
|
| 20 |
+
|
| 21 |
+
Parameters:
|
| 22 |
+
-----------
|
| 23 |
+
config : Config
|
| 24 |
+
Experiment configuration
|
| 25 |
+
"""
|
| 26 |
+
self.config = config
|
| 27 |
+
self.validation_results = {}
|
| 28 |
+
self.quality_metrics = {}
|
| 29 |
+
self.issues_found = {}
|
| 30 |
+
|
| 31 |
+
def validate(
|
| 32 |
+
self,
|
| 33 |
+
data: pd.DataFrame,
|
| 34 |
+
stage: str = 'final',
|
| 35 |
+
rules: Dict = None,
|
| 36 |
+
detailed: bool = True
|
| 37 |
+
) -> Dict:
|
| 38 |
+
"""
|
| 39 |
+
Validate data quality
|
| 40 |
+
|
| 41 |
+
Parameters:
|
| 42 |
+
-----------
|
| 43 |
+
data : pd.DataFrame
|
| 44 |
+
Input data
|
| 45 |
+
stage : str
|
| 46 |
+
Validation stage: 'raw', 'processed', 'final'
|
| 47 |
+
rules : Dict, optional
|
| 48 |
+
Validation rules. If None, uses configuration defaults.
|
| 49 |
+
detailed : bool
|
| 50 |
+
Whether to perform detailed validation
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
--------
|
| 54 |
+
Dict
|
| 55 |
+
Validation results
|
| 56 |
+
"""
|
| 57 |
+
logger.info("\n" + "="*80)
|
| 58 |
+
logger.info(f"DATA VALIDATION ({stage.upper()})")
|
| 59 |
+
logger.info("="*80)
|
| 60 |
+
|
| 61 |
+
rules = rules or self.config.validation_rules
|
| 62 |
+
|
| 63 |
+
validation_results = {
|
| 64 |
+
'stage': stage,
|
| 65 |
+
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
| 66 |
+
'data_shape': list(data.shape),
|
| 67 |
+
'basic_checks': {},
|
| 68 |
+
'quality_metrics': {},
|
| 69 |
+
'issues': {},
|
| 70 |
+
'recommendations': [],
|
| 71 |
+
'overall_score': 0,
|
| 72 |
+
'status': 'PASS'
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
# Basic checks
|
| 76 |
+
validation_results['basic_checks'] = self._basic_checks(data, rules)
|
| 77 |
+
|
| 78 |
+
# Quality checks
|
| 79 |
+
validation_results['quality_metrics'] = self._quality_metrics(data, rules)
|
| 80 |
+
|
| 81 |
+
# Problem detection
|
| 82 |
+
if detailed:
|
| 83 |
+
validation_results['issues'] = self._find_issues(data, rules)
|
| 84 |
+
|
| 85 |
+
# Recommendation generation
|
| 86 |
+
validation_results['recommendations'] = self._generate_recommendations(
|
| 87 |
+
validation_results['basic_checks'],
|
| 88 |
+
validation_results['quality_metrics'],
|
| 89 |
+
validation_results['issues']
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Overall score calculation
|
| 93 |
+
validation_results['overall_score'] = self._calculate_overall_score(validation_results)
|
| 94 |
+
|
| 95 |
+
# Status determination
|
| 96 |
+
if validation_results['overall_score'] >= 80:
|
| 97 |
+
validation_results['status'] = 'PASS'
|
| 98 |
+
elif validation_results['overall_score'] >= 60:
|
| 99 |
+
validation_results['status'] = 'WARNING'
|
| 100 |
+
else:
|
| 101 |
+
validation_results['status'] = 'FAIL'
|
| 102 |
+
|
| 103 |
+
# Save results
|
| 104 |
+
self.validation_results[stage] = validation_results
|
| 105 |
+
self.quality_metrics[stage] = validation_results['quality_metrics']
|
| 106 |
+
|
| 107 |
+
# Log results
|
| 108 |
+
self._log_validation_results(validation_results)
|
| 109 |
+
|
| 110 |
+
return validation_results
|
| 111 |
+
|
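A hypothetical call site for `validate()`. The `rules` dict only sets the two thresholds that `_basic_checks` below actually reads (`min_rows`, `max_missing_percentage`); the `Config()` construction and the structure of `temp_data.csv` are assumptions, not taken from this commit.

```python
import pandas as pd
from config.config import Config
from validation.data_validator import DataValidator

config = Config()  # assumed: default construction provides target_column, validation_rules, results_dir
df = pd.read_csv("temp_data.csv", parse_dates=[0], index_col=0)  # assumed: first column is the datetime index

validator = DataValidator(config)
results = validator.validate(
    df,
    stage="raw",
    rules={"min_rows": 100, "max_missing_percentage": 30},
    detailed=True,
)
print(results["status"], results["overall_score"])
```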
| 112 |
+
def _basic_checks(self, data: pd.DataFrame, rules: Dict) -> Dict:
|
| 113 |
+
"""Basic data checks"""
|
| 114 |
+
checks = {}
|
| 115 |
+
|
| 116 |
+
# 1. Data size check
|
| 117 |
+
checks['min_rows'] = {
|
| 118 |
+
'value': len(data),
|
| 119 |
+
'threshold': rules.get('min_rows', 100),
|
| 120 |
+
'passed': len(data) >= rules.get('min_rows', 100)
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
# 2. Target variable presence check
|
| 124 |
+
target = self.config.target_column
|
| 125 |
+
checks['has_target'] = {
|
| 126 |
+
'value': target in data.columns,
|
| 127 |
+
'passed': target in data.columns
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# 3. Missing values check
|
| 131 |
+
missing_percentage = (data.isnull().sum().sum() / data.size) * 100
|
| 132 |
+
checks['missing_percentage'] = {
|
| 133 |
+
'value': missing_percentage,
|
| 134 |
+
'threshold': rules.get('max_missing_percentage', 30),
|
| 135 |
+
'passed': missing_percentage <= rules.get('max_missing_percentage', 30)
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# 4. Duplicates check
|
| 139 |
+
duplicate_count = data.duplicated().sum()
|
| 140 |
+
duplicate_percentage = (duplicate_count / len(data)) * 100
|
| 141 |
+
checks['duplicates'] = {
|
| 142 |
+
'value': duplicate_percentage,
|
| 143 |
+
'threshold': 5, # Maximum 5% duplicates
|
| 144 |
+
'passed': duplicate_percentage <= 5
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
# 5. Data types check
|
| 148 |
+
numeric_count = len(data.select_dtypes(include=[np.number]).columns)
|
| 149 |
+
checks['numeric_features'] = {
|
| 150 |
+
'value': numeric_count,
|
| 151 |
+
'passed': numeric_count >= 1 # At least one numeric feature required
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
return checks
|
| 155 |
+
|
| 156 |
+
def _quality_metrics(self, data: pd.DataFrame, rules: Dict) -> Dict:
|
| 157 |
+
"""Data quality metrics"""
|
| 158 |
+
metrics = {}
|
| 159 |
+
|
| 160 |
+
# 1. Numeric features statistics
|
| 161 |
+
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
| 162 |
+
|
| 163 |
+
if len(numeric_cols) > 0:
|
| 164 |
+
numeric_stats = {}
|
| 165 |
+
for col in numeric_cols:
|
| 166 |
+
col_data = data[col].dropna()
|
| 167 |
+
if len(col_data) > 0:
|
| 168 |
+
numeric_stats[col] = {
|
| 169 |
+
'mean': float(col_data.mean()),
|
| 170 |
+
'std': float(col_data.std()),
|
| 171 |
+
'skewness': float(col_data.skew()),
|
| 172 |
+
'kurtosis': float(col_data.kurtosis()),
|
| 173 |
+
'zeros_percentage': float((col_data == 0).sum() / len(col_data) * 100),
|
| 174 |
+
'unique_percentage': float(col_data.nunique() / len(col_data) * 100)
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
metrics['numeric_statistics'] = numeric_stats
|
| 178 |
+
|
| 179 |
+
# 2. Data stability (for time series)
|
| 180 |
+
if isinstance(data.index, pd.DatetimeIndex):
|
| 181 |
+
stability_metrics = self._calculate_temporal_stability(data)
|
| 182 |
+
metrics['temporal_stability'] = stability_metrics
|
| 183 |
+
|
| 184 |
+
# 3. Feature informativeness
|
| 185 |
+
if self.config.target_column in data.columns:
|
| 186 |
+
informativeness = self._calculate_feature_informativeness(data)
|
| 187 |
+
metrics['feature_informativeness'] = informativeness
|
| 188 |
+
|
| 189 |
+
# 4. Target variable quality
|
| 190 |
+
target = self.config.target_column
|
| 191 |
+
if target in data.columns:
|
| 192 |
+
target_data = data[target].dropna()
|
| 193 |
+
if len(target_data) > 0:
|
| 194 |
+
target_metrics = {
|
| 195 |
+
'missing_percentage': float(data[target].isnull().sum() / len(data) * 100),
|
| 196 |
+
'unique_values': int(target_data.nunique()),
|
| 197 |
+
'is_constant': bool(target_data.nunique() <= 1),
|
| 198 |
+
'has_outliers': self._check_target_outliers(target_data),
|
| 199 |
+
'distribution_type': self._identify_distribution(target_data)
|
| 200 |
+
}
|
| 201 |
+
metrics['target_quality'] = target_metrics
|
| 202 |
+
|
| 203 |
+
# 5. Class balance (for classification) - not applicable here, but kept as placeholder
|
| 204 |
+
metrics['class_balance'] = {'note': 'Not applicable for regression'}
|
| 205 |
+
|
| 206 |
+
return metrics
|
| 207 |
+
|
| 208 |
+
def _calculate_temporal_stability(self, data: pd.DataFrame) -> Dict:
|
| 209 |
+
"""Calculate time series stability metrics"""
|
| 210 |
+
stability = {}
|
| 211 |
+
|
| 212 |
+
if not isinstance(data.index, pd.DatetimeIndex):
|
| 213 |
+
return stability
|
| 214 |
+
|
| 215 |
+
# Split into periods (e.g., by years)
|
| 216 |
+
if 'year' not in data.columns:
|
| 217 |
+
data_copy = data.copy()
|
| 218 |
+
data_copy['year'] = data_copy.index.year
|
| 219 |
+
else:
|
| 220 |
+
data_copy = data
|
| 221 |
+
|
| 222 |
+
years = sorted(data_copy['year'].unique())
|
| 223 |
+
|
| 224 |
+
if len(years) > 1:
|
| 225 |
+
# Statistics by years for numeric columns
|
| 226 |
+
year_stats = {}
|
| 227 |
+
for col in data.select_dtypes(include=[np.number]).columns[:5]: # First 5 columns
|
| 228 |
+
yearly_means = data_copy.groupby('year')[col].mean()
|
| 229 |
+
yearly_stds = data_copy.groupby('year')[col].std()
|
| 230 |
+
|
| 231 |
+
# Coefficient of variation between years
|
| 232 |
+
if yearly_means.std() > 0:
|
| 233 |
+
cv_between_years = yearly_means.std() / yearly_means.mean()
|
| 234 |
+
else:
|
| 235 |
+
cv_between_years = 0
|
| 236 |
+
|
| 237 |
+
year_stats[col] = {
|
| 238 |
+
'yearly_means': yearly_means.to_dict(),
|
| 239 |
+
'yearly_stds': yearly_stds.to_dict(),
|
| 240 |
+
'cv_between_years': float(cv_between_years),
|
| 241 |
+
'mean_stability': float(1 - cv_between_years) # 1 - CV, closer to 1 means more stable
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
stability['yearly_statistics'] = year_stats
|
| 245 |
+
|
| 246 |
+
# Check for time gaps
|
| 247 |
+
time_diff = pd.Series(data.index).diff().dropna()
|
| 248 |
+
if len(time_diff) > 0:
|
| 249 |
+
max_gap = time_diff.max()
|
| 250 |
+
avg_gap = time_diff.mean()
|
| 251 |
+
gap_std = time_diff.std()
|
| 252 |
+
|
| 253 |
+
stability['time_gaps'] = {
|
| 254 |
+
'max_gap_days': float(max_gap.days if hasattr(max_gap, 'days') else max_gap),
|
| 255 |
+
'avg_gap_days': float(avg_gap.days if hasattr(avg_gap, 'days') else avg_gap),
|
| 256 |
+
'gap_std': float(gap_std.days if hasattr(gap_std, 'days') else gap_std),
|
| 257 |
+
'has_irregular_gaps': gap_std > avg_gap * 0.5 # If standard deviation > 50% of mean
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
# Seasonal stability
|
| 261 |
+
if len(data) > 365:
|
| 262 |
+
try:
|
| 263 |
+
# Analyse seasonal patterns
|
| 264 |
+
seasonal_stability = self._analyse_seasonal_stability(data)
|
| 265 |
+
stability['seasonal_stability'] = seasonal_stability
|
| 266 |
+
except:
|
| 267 |
+
pass
|
| 268 |
+
|
| 269 |
+
return stability
|
| 270 |
+
|
| 271 |
+
def _analyse_seasonal_stability(self, data: pd.DataFrame) -> Dict:
|
| 272 |
+
"""Analyse seasonal patterns stability"""
|
| 273 |
+
if not isinstance(data.index, pd.DatetimeIndex):
|
| 274 |
+
return {}
|
| 275 |
+
|
| 276 |
+
# For simplicity, analyse only target variable
|
| 277 |
+
target = self.config.target_column
|
| 278 |
+
if target not in data.columns:
|
| 279 |
+
return {}
|
| 280 |
+
|
| 281 |
+
series = data[target]
|
| 282 |
+
|
| 283 |
+
# Split by years and compare seasonal patterns
|
| 284 |
+
data_copy = data.copy()
|
| 285 |
+
data_copy['year'] = data_copy.index.year
|
| 286 |
+
data_copy['month'] = data_copy.index.month
|
| 287 |
+
|
| 288 |
+
if 'year' in data_copy.columns and 'month' in data_copy.columns:
|
| 289 |
+
monthly_means = data_copy.groupby(['year', 'month'])[target].mean().unstack()
|
| 290 |
+
|
| 291 |
+
if not monthly_means.empty:
|
| 292 |
+
# Correlation between years
|
| 293 |
+
yearly_corr = monthly_means.corr().mean().mean()
|
| 294 |
+
|
| 295 |
+
# Variation between years
|
| 296 |
+
monthly_cv = monthly_means.std() / monthly_means.mean()
|
| 297 |
+
avg_monthly_cv = monthly_cv.mean()
|
| 298 |
+
|
| 299 |
+
return {
|
| 300 |
+
'yearly_correlation': float(yearly_corr),
|
| 301 |
+
'average_monthly_cv': float(avg_monthly_cv),
|
| 302 |
+
'seasonal_consistency': 'high' if yearly_corr > 0.8 and avg_monthly_cv < 0.3 else
|
| 303 |
+
'medium' if yearly_corr > 0.6 else 'low'
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
return {}
|
| 307 |
+
|
| 308 |
+
def _calculate_feature_informativeness(self, data: pd.DataFrame) -> Dict:
|
| 309 |
+
"""Calculate feature informativeness"""
|
| 310 |
+
informativeness = {}
|
| 311 |
+
|
| 312 |
+
target = self.config.target_column
|
| 313 |
+
if target not in data.columns:
|
| 314 |
+
return informativeness
|
| 315 |
+
|
| 316 |
+
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
| 317 |
+
numeric_cols = [col for col in numeric_cols if col != target]
|
| 318 |
+
|
| 319 |
+
for col in numeric_cols[:20]: # Limit number of features for analysis
|
| 320 |
+
try:
|
| 321 |
+
# Correlation with target variable
|
| 322 |
+
correlation = data[col].corr(data[target])
|
| 323 |
+
|
| 324 |
+
# Mutual information (approximated)
|
| 325 |
+
# For simplicity, use absolute correlation as informativeness measure
|
| 326 |
+
informativeness[col] = {
|
| 327 |
+
'correlation_with_target': float(correlation),
|
| 328 |
+
'abs_correlation': float(abs(correlation)),
|
| 329 |
+
'informativeness': 'high' if abs(correlation) > 0.5 else
|
| 330 |
+
'medium' if abs(correlation) > 0.3 else 'low'
|
| 331 |
+
}
|
| 332 |
+
except:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
return informativeness
|
| 336 |
+
|
| 337 |
+
def _check_target_outliers(self, target_series: pd.Series) -> Dict:
|
| 338 |
+
"""Check target variable for outliers"""
|
| 339 |
+
if len(target_series) < 10:
|
| 340 |
+
return {'has_outliers': False, 'outlier_percentage': 0}
|
| 341 |
+
|
| 342 |
+
q1 = target_series.quantile(0.25)
|
| 343 |
+
q3 = target_series.quantile(0.75)
|
| 344 |
+
iqr = q3 - q1
|
| 345 |
+
|
| 346 |
+
if iqr > 0:
|
| 347 |
+
lower_bound = q1 - 1.5 * iqr
|
| 348 |
+
upper_bound = q3 + 1.5 * iqr
|
| 349 |
+
|
| 350 |
+
outliers = target_series[(target_series < lower_bound) | (target_series > upper_bound)]
|
| 351 |
+
outlier_percentage = len(outliers) / len(target_series) * 100
|
| 352 |
+
|
| 353 |
+
return {
|
| 354 |
+
'has_outliers': len(outliers) > 0,
|
| 355 |
+
'outlier_count': int(len(outliers)),
|
| 356 |
+
'outlier_percentage': float(outlier_percentage),
|
| 357 |
+
'outlier_bounds': {'lower': float(lower_bound), 'upper': float(upper_bound)}
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
return {'has_outliers': False, 'outlier_percentage': 0}
|
| 361 |
+
|
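The 1.5×IQR fences above can be checked by hand: with Q1 = 10 and Q3 = 20, IQR = 10, so the bounds are 10 − 15 = −5 and 20 + 15 = 35, and anything outside that band is flagged. A tiny numeric check:

```python
import pandas as pd

s = pd.Series([10, 12, 14, 15, 16, 18, 20, 100])   # 100 is an obvious outlier
q1, q3 = s.quantile(0.25), s.quantile(0.75)
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
print(lower, upper, list(s[(s < lower) | (s > upper)]))  # only 100 falls outside the fences
```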
| 362 |
+
def _identify_distribution(self, series: pd.Series) -> str:
|
| 363 |
+
"""Identify distribution type"""
|
| 364 |
+
if len(series) < 30:
|
| 365 |
+
return 'insufficient_data'
|
| 366 |
+
|
| 367 |
+
skewness = series.skew()
|
| 368 |
+
kurtosis = series.kurtosis()
|
| 369 |
+
|
| 370 |
+
if abs(skewness) < 0.5 and abs(kurtosis) < 1:
|
| 371 |
+
return 'normal_like'
|
| 372 |
+
elif skewness > 1:
|
| 373 |
+
return 'right_skewed'
|
| 374 |
+
elif skewness < -1:
|
| 375 |
+
return 'left_skewed'
|
| 376 |
+
elif kurtosis > 3:
|
| 377 |
+
return 'heavy_tailed'
|
| 378 |
+
elif kurtosis < 2:
|
| 379 |
+
return 'light_tailed'
|
| 380 |
+
else:
|
| 381 |
+
return 'unknown'
|
| 382 |
+
|
| 383 |
+
def _find_issues(self, data: pd.DataFrame, rules: Dict) -> Dict:
|
| 384 |
+
"""Find data problems"""
|
| 385 |
+
issues = {
|
| 386 |
+
'critical': [],
|
| 387 |
+
'warning': [],
|
| 388 |
+
'info': []
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
# 1. Check missing values in important features
|
| 392 |
+
missing_info = data.isnull().sum()
|
| 393 |
+
high_missing_cols = missing_info[missing_info / len(data) * 100 > 20].index.tolist()
|
| 394 |
+
|
| 395 |
+
for col in high_missing_cols:
|
| 396 |
+
missing_pct = missing_info[col] / len(data) * 100
|
| 397 |
+
if missing_pct > 50:
|
| 398 |
+
issues['critical'].append(f"Column '{col}': {missing_pct:.1f}% missing values (critical)")
|
| 399 |
+
elif missing_pct > 20:
|
| 400 |
+
issues['warning'].append(f"Column '{col}': {missing_pct:.1f}% missing values")
|
| 401 |
+
|
| 402 |
+
# 2. Check constant features
|
| 403 |
+
for col in data.columns:
|
| 404 |
+
if data[col].nunique() <= 1:
|
| 405 |
+
issues['critical'].append(f"Column '{col}': constant value")
|
| 406 |
+
|
| 407 |
+
# 3. Check feature correlation with itself (lags)
|
| 408 |
+
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
| 409 |
+
for col in numeric_cols:
|
| 410 |
+
if '_lag_' in col or '_diff_' in col:
|
| 411 |
+
base_col = col.split('_lag_')[0] if '_lag_' in col else col.split('_diff_')[0]
|
| 412 |
+
if base_col in numeric_cols:
|
| 413 |
+
correlation = data[col].corr(data[base_col])
|
| 414 |
+
if pd.notna(correlation) and abs(correlation) > 0.95:
|
| 415 |
+
issues['info'].append(f"Column '{col}': high correlation with '{base_col}' ({correlation:.3f})")
|
| 416 |
+
|
| 417 |
+
# 4. Check time gaps
|
| 418 |
+
if isinstance(data.index, pd.DatetimeIndex):
|
| 419 |
+
time_diff = pd.Series(data.index).diff().dropna()
|
| 420 |
+
if len(time_diff) > 0:
|
| 421 |
+
max_gap = time_diff.max()
|
| 422 |
+
if hasattr(max_gap, 'days') and max_gap.days > 30:
|
| 423 |
+
issues['warning'].append(f"Detected time gap: {max_gap.days} days")
|
| 424 |
+
|
| 425 |
+
# 5. Check target variable
|
| 426 |
+
target = self.config.target_column
|
| 427 |
+
if target in data.columns:
|
| 428 |
+
target_data = data[target].dropna()
|
| 429 |
+
if len(target_data) > 0:
|
| 430 |
+
if target_data.nunique() <= 1:
|
| 431 |
+
issues['critical'].append(f"Target variable '{target}': constant value")
|
| 432 |
+
|
| 433 |
+
# Check for outliers
|
| 434 |
+
outlier_check = self._check_target_outliers(target_data)
|
| 435 |
+
if outlier_check.get('has_outliers', False) and outlier_check.get('outlier_percentage', 0) > 10:
|
| 436 |
+
issues['warning'].append(f"Target variable '{target}': {outlier_check['outlier_percentage']:.1f}% outliers")
|
| 437 |
+
|
| 438 |
+
# 6. Check multicollinearity (simplified)
|
| 439 |
+
if len(numeric_cols) > 5:
|
| 440 |
+
corr_matrix = data[numeric_cols].corr().abs()
|
| 441 |
+
high_corr_pairs = []
|
| 442 |
+
|
| 443 |
+
for i in range(len(corr_matrix.columns)):
|
| 444 |
+
for j in range(i+1, len(corr_matrix.columns)):
|
| 445 |
+
if corr_matrix.iloc[i, j] > 0.9:
|
| 446 |
+
col1 = corr_matrix.columns[i]
|
| 447 |
+
col2 = corr_matrix.columns[j]
|
| 448 |
+
high_corr_pairs.append((col1, col2, corr_matrix.iloc[i, j]))
|
| 449 |
+
|
| 450 |
+
if len(high_corr_pairs) > 5:
|
| 451 |
+
issues['warning'].append(f"Detected multicollinearity: {len(high_corr_pairs)} pairs with correlation > 0.9")
|
| 452 |
+
|
| 453 |
+
return issues
|
| 454 |
+
|
| 455 |
+
def _generate_recommendations(
|
| 456 |
+
self,
|
| 457 |
+
basic_checks: Dict,
|
| 458 |
+
quality_metrics: Dict,
|
| 459 |
+
issues: Dict
|
| 460 |
+
) -> List[str]:
|
| 461 |
+
"""Generate data improvement recommendations"""
|
| 462 |
+
recommendations = []
|
| 463 |
+
|
| 464 |
+
# Recommendations based on basic checks
|
| 465 |
+
for check_name, check_info in basic_checks.items():
|
| 466 |
+
if not check_info.get('passed', True):
|
| 467 |
+
if check_name == 'min_rows':
|
| 468 |
+
recommendations.append(f"Increase data volume: current row count ({check_info['value']}) below minimum threshold ({check_info['threshold']})")
|
| 469 |
+
elif check_name == 'has_target':
|
| 470 |
+
recommendations.append(f"Add target variable '{self.config.target_column}' to data")
|
| 471 |
+
elif check_name == 'missing_percentage':
|
| 472 |
+
recommendations.append(f"Handle missing values: {check_info['value']:.1f}% missing exceeds threshold {check_info['threshold']}%")
|
| 473 |
+
elif check_name == 'duplicates':
|
| 474 |
+
recommendations.append(f"Remove duplicates: {check_info['value']:.1f}% duplicate rows")
|
| 475 |
+
|
| 476 |
+
# Recommendations based on issues
|
| 477 |
+
if issues.get('critical'):
|
| 478 |
+
recommendations.append("Resolve critical issues before using data")
|
| 479 |
+
|
| 480 |
+
if issues.get('warning'):
|
| 481 |
+
recommendations.append("Consider addressing warnings to improve data quality")
|
| 482 |
+
|
| 483 |
+
# Recommendations based on quality metrics
|
| 484 |
+
target_metrics = quality_metrics.get('target_quality', {})
|
| 485 |
+
if target_metrics.get('is_constant', False):
|
| 486 |
+
recommendations.append(f"Target variable '{self.config.target_column}' is constant, different target variable needed")
|
| 487 |
+
|
| 488 |
+
if target_metrics.get('has_outliers', {}).get('has_outliers', False):
|
| 489 |
+
outlier_pct = target_metrics['has_outliers'].get('outlier_percentage', 0)
|
| 490 |
+
if outlier_pct > 5:
|
| 491 |
+
recommendations.append(f"Handle outliers in target variable: {outlier_pct:.1f}% outliers")
|
| 492 |
+
|
| 493 |
+
# Time series stability recommendations
|
| 494 |
+
temporal_stability = quality_metrics.get('temporal_stability', {})
|
| 495 |
+
if temporal_stability.get('time_gaps', {}).get('has_irregular_gaps', False):
|
| 496 |
+
recommendations.append("Detected irregular time intervals, consider resampling")
|
| 497 |
+
|
| 498 |
+
return recommendations
|
| 499 |
+
|
| 500 |
+
def _calculate_overall_score(self, validation_results: Dict) -> float:
|
| 501 |
+
"""Calculate overall data quality score"""
|
| 502 |
+
score = 100
|
| 503 |
+
|
| 504 |
+
# Penalties for basic checks
|
| 505 |
+
basic_checks = validation_results.get('basic_checks', {})
|
| 506 |
+
for check_name, check_info in basic_checks.items():
|
| 507 |
+
if not check_info.get('passed', True):
|
| 508 |
+
if check_name == 'min_rows':
|
| 509 |
+
score -= 30
|
| 510 |
+
elif check_name == 'has_target':
|
| 511 |
+
score -= 50
|
| 512 |
+
elif check_name == 'missing_percentage':
|
| 513 |
+
missing_pct = check_info.get('value', 0)
|
| 514 |
+
if missing_pct > 50:
|
| 515 |
+
score -= 40
|
| 516 |
+
elif missing_pct > 20:
|
| 517 |
+
score -= 20
|
| 518 |
+
elif missing_pct > 5:
|
| 519 |
+
score -= 10
|
| 520 |
+
elif check_name == 'duplicates':
|
| 521 |
+
duplicate_pct = check_info.get('value', 0)
|
| 522 |
+
if duplicate_pct > 20:
|
| 523 |
+
score -= 30
|
| 524 |
+
elif duplicate_pct > 10:
|
| 525 |
+
score -= 15
|
| 526 |
+
elif duplicate_pct > 5:
|
| 527 |
+
score -= 5
|
| 528 |
+
|
| 529 |
+
# Penalties for issues
|
| 530 |
+
issues = validation_results.get('issues', {})
|
| 531 |
+
if issues.get('critical'):
|
| 532 |
+
score -= len(issues['critical']) * 20
|
| 533 |
+
|
| 534 |
+
if issues.get('warning'):
|
| 535 |
+
score -= len(issues['warning']) * 5
|
| 536 |
+
|
| 537 |
+
# Bonuses for good metrics
|
| 538 |
+
quality_metrics = validation_results.get('quality_metrics', {})
|
| 539 |
+
target_metrics = quality_metrics.get('target_quality', {})
|
| 540 |
+
|
| 541 |
+
if not target_metrics.get('is_constant', True):
|
| 542 |
+
score += 10
|
| 543 |
+
|
| 544 |
+
if target_metrics.get('missing_percentage', 100) < 1:
|
| 545 |
+
score += 5
|
| 546 |
+
|
| 547 |
+
# Limit score to 0-100 range
|
| 548 |
+
return max(0, min(100, score))
|
| 549 |
+
|
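To make the arithmetic concrete: a dataset with 35% missing values (above the 30% threshold and inside the 20-50% band, so −20), one critical issue (−20) and two warnings (−5 each) starts from 100 and ends at 60 once the +10 bonus for a non-constant target is added, which lands exactly on the WARNING boundary.

```python
# Worked example of the penalties and bonuses applied above
score = 100
score -= 20          # missing_percentage check failed with 20% < missing <= 50%
score -= 1 * 20      # one critical issue
score -= 2 * 5       # two warnings
score += 10          # target variable is not constant
status = "PASS" if score >= 80 else "WARNING" if score >= 60 else "FAIL"
print(score, status)  # 60 WARNING
```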
| 550 |
+
def _log_validation_results(self, validation_results: Dict) -> None:
|
| 551 |
+
"""Log validation results"""
|
| 552 |
+
stage = validation_results['stage']
|
| 553 |
+
status = validation_results['status']
|
| 554 |
+
score = validation_results['overall_score']
|
| 555 |
+
|
| 556 |
+
logger.info(f"VALIDATION RESULTS ({stage}):")
|
| 557 |
+
logger.info(f" Status: {status}")
|
| 558 |
+
logger.info(f" Overall score: {score}/100")
|
| 559 |
+
logger.info(f" Data shape: {validation_results['data_shape'][0]}x{validation_results['data_shape'][1]}")
|
| 560 |
+
|
| 561 |
+
# Basic checks
|
| 562 |
+
logger.info("\nBASIC CHECKS:")
|
| 563 |
+
for check_name, check_info in validation_results['basic_checks'].items():
|
| 564 |
+
status_icon = "✓" if check_info.get('passed', True) else "✗"
|
| 565 |
+
logger.info(f" {status_icon} {check_name}: {check_info.get('value', 'N/A')}")
|
| 566 |
+
|
| 567 |
+
# Issues
|
| 568 |
+
issues = validation_results['issues']
|
| 569 |
+
if any(issues.values()):
|
| 570 |
+
logger.info("\nDETECTED ISSUES:")
|
| 571 |
+
for severity, issue_list in issues.items():
|
| 572 |
+
if issue_list:
|
| 573 |
+
logger.info(f" {severity.upper()}:")
|
| 574 |
+
for issue in issue_list[:5]: # Show only first 5 issues of each type
|
| 575 |
+
logger.info(f" - {issue}")
|
| 576 |
+
if len(issue_list) > 5:
|
| 577 |
+
logger.info(f" ... and {len(issue_list) - 5} more issues")
|
| 578 |
+
else:
|
| 579 |
+
logger.info("\n✓ No issues detected")
|
| 580 |
+
|
| 581 |
+
# Recommendations
|
| 582 |
+
recommendations = validation_results['recommendations']
|
| 583 |
+
if recommendations:
|
| 584 |
+
logger.info("\nRECOMMENDATIONS:")
|
| 585 |
+
for i, rec in enumerate(recommendations, 1):
|
| 586 |
+
logger.info(f" {i}. {rec}")
|
| 587 |
+
|
| 588 |
+
# Conclusion
|
| 589 |
+
if status == 'PASS':
|
| 590 |
+
logger.info("\n✓ Data passed validation and is ready for use")
|
| 591 |
+
elif status == 'WARNING':
|
| 592 |
+
logger.info("\n⚠ Data requires attention, there are issues to address")
|
| 593 |
+
else:
|
| 594 |
+
logger.info("\n✗ Data requires significant improvement before use")
|
| 595 |
+
|
| 596 |
+
def generate_report(self, stage: str = 'final') -> Dict:
|
| 597 |
+
"""Generate detailed validation report"""
|
| 598 |
+
if stage not in self.validation_results:
|
| 599 |
+
return {}
|
| 600 |
+
|
| 601 |
+
report = self.validation_results[stage].copy()
|
| 602 |
+
|
| 603 |
+
# Add metadata
|
| 604 |
+
report['config'] = self.config.to_dict()
|
| 605 |
+
report['validator_version'] = '1.0'
|
| 606 |
+
report['generation_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 607 |
+
|
| 608 |
+
# Add detailed metrics
|
| 609 |
+
quality_metrics = report.get('quality_metrics', {})
|
| 610 |
+
|
| 611 |
+
if 'numeric_statistics' in quality_metrics:
|
| 612 |
+
# Numeric features summary
|
| 613 |
+
numeric_stats = quality_metrics['numeric_statistics']
|
| 614 |
+
report['numeric_summary'] = {
|
| 615 |
+
'total_numeric_features': len(numeric_stats),
|
| 616 |
+
'features_with_high_skewness': sum(1 for s in numeric_stats.values() if abs(s.get('skewness', 0)) > 1),
|
| 617 |
+
'features_with_high_kurtosis': sum(1 for s in numeric_stats.values() if abs(s.get('kurtosis', 0)) > 3),
|
| 618 |
+
'features_with_many_zeros': sum(1 for s in numeric_stats.values() if s.get('zeros_percentage', 0) > 50)
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
return report
|
| 622 |
+
|
| 623 |
+
def save_report(self, stage: str = 'final', path: str = None) -> None:
|
| 624 |
+
"""Save validation report to file"""
|
| 625 |
+
if stage not in self.validation_results:
|
| 626 |
+
logger.warning(f"Report for stage '{stage}' not found")
|
| 627 |
+
return
|
| 628 |
+
|
| 629 |
+
report = self.generate_report(stage)
|
| 630 |
+
|
| 631 |
+
if path is None:
|
| 632 |
+
path = f'{self.config.results_dir}/reports/validation_report_{stage}.json'
|
| 633 |
+
|
| 634 |
+
# Create directory if needed
|
| 635 |
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
| 636 |
+
|
| 637 |
+
# Custom JSON encoder
|
| 638 |
+
class NumpyEncoder(json.JSONEncoder):
|
| 639 |
+
def default(self, obj):
|
| 640 |
+
if isinstance(obj, (np.integer, np.floating)):
|
| 641 |
+
if np.isnan(obj):
|
| 642 |
+
return None
|
| 643 |
+
return float(obj)
|
| 644 |
+
elif isinstance(obj, np.bool_):
|
| 645 |
+
return bool(obj)
|
| 646 |
+
elif isinstance(obj, np.ndarray):
|
| 647 |
+
return obj.tolist()
|
| 648 |
+
elif isinstance(obj, pd.Timestamp):
|
| 649 |
+
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
| 650 |
+
return super().default(obj)
|
| 651 |
+
|
| 652 |
+
with open(path, 'w', encoding='utf-8') as f:
|
| 653 |
+
json.dump(report, f, indent=4, ensure_ascii=False, cls=NumpyEncoder)
|
| 654 |
+
|
| 655 |
+
logger.info(f"✓ Validation report saved: {path}")
|
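The `NumpyEncoder` defined inside `save_report` is a common pattern for making numpy scalars, arrays and pandas timestamps JSON-serialisable. A standalone check of the same idea (note that `np.float64` already subclasses Python `float`, so the encoder mainly matters for integer, float32, bool, array and Timestamp values):

```python
import json
import numpy as np
import pandas as pd

class NumpyEncoder(json.JSONEncoder):
    # Same idea as the encoder inside save_report above
    def default(self, obj):
        if isinstance(obj, (np.integer, np.floating)):
            return None if (isinstance(obj, np.floating) and np.isnan(obj)) else float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, pd.Timestamp):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        return super().default(obj)

report = {"rows": np.int64(1000), "score": np.float32(87.5),
          "passed": np.bool_(True), "generated": pd.Timestamp("2024-01-01")}
print(json.dumps(report, cls=NumpyEncoder))
```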
visualization/__init__.py
ADDED
File without changes

visualization/visualization_manager.py
ADDED
@@ -0,0 +1,1462 @@
# ============================================
# CLASS 13: VISUALISATION MANAGER (UPDATED)
# ============================================
import os
from datetime import datetime
import json
from typing import Dict, List, Optional, Tuple, Union, Any

import pandas as pd
import numpy as np
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import gaussian_kde
import matplotlib
matplotlib.use('Agg')  # Use non-display backend

from config.config import Config
import logging

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VisualisationManager:
    """Class for managing all visualisations"""

    def __init__(self, config: Config):
        """
        Initialise visualisation manager

        Parameters:
        -----------
        config : Config
            Experiment configuration
        """
        self.config = config
        self.plots_generated = {}
        self.plot_files = {}
        self.figure_count = 0

        # Create directory structure for saving plots
        self._create_directory_structure()

    def _create_directory_structure(self) -> None:
        """Create directory structure for saving plots"""
        base_dir = self.config.results_dir

        # Main plot directories
        self.plots_dir = os.path.join(base_dir, "plots")
        self.correlations_dir = os.path.join(base_dir, "plots", "correlations")
        self.distributions_dir = os.path.join(base_dir, "plots", "distributions")
        self.features_dir = os.path.join(base_dir, "plots", "features")
        self.time_series_dir = os.path.join(base_dir, "plots", "time_series")
        self.preprocessing_dir = os.path.join(base_dir, "plots", "preprocessing")
        self.summary_dir = os.path.join(base_dir, "plots", "summary")
        self.reports_dir = os.path.join(base_dir, "reports")

        # Create directories
        directories = [
            self.plots_dir,
            self.correlations_dir,
            self.distributions_dir,
            self.features_dir,
            self.time_series_dir,
            self.preprocessing_dir,
            self.summary_dir,
            self.reports_dir
        ]

        for directory in directories:
            os.makedirs(directory, exist_ok=True)
            logger.debug(f"Created directory: {directory}")

    def _save_figure(self, fig: plt.Figure, filename: str,
                     subdirectory: str = None, dpi: int = 300) -> str:
        """
        Save plot and close it

        Parameters:
        -----------
        fig : matplotlib.figure.Figure
            Plot figure object
        filename : str
            Filename for saving
        subdirectory : str, optional
            Subdirectory for saving
        dpi : int
            Save quality

        Returns:
        --------
        str : full path to saved file
        """
        if not filename.endswith('.png'):
            filename = f"{filename}.png"

        if subdirectory:
            save_dir = os.path.join(self.plots_dir, subdirectory)
            os.makedirs(save_dir, exist_ok=True)
        else:
            save_dir = self.plots_dir

        filepath = os.path.join(save_dir, filename)

        try:
            fig.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white')
            logger.info(f"✓ Plot saved: {filepath}")
        except Exception as e:
            logger.error(f"✗ Error saving plot {filename}: {e}")
            filepath = None

        # Close plot without display
        plt.close(fig)

        return filepath

    # ============================================
    # MAIN VISUALISATION METHODS
    # ============================================

    def create_summary_dashboard(
        self,
        data: pd.DataFrame,
        preprocessing_stages: Dict = None,
        filename: str = "summary_dashboard"
    ) -> str:
        """
        Create summary visualisation dashboard

        Parameters:
        -----------
        data : pd.DataFrame
            Data for visualisation
        preprocessing_stages : Dict, optional
            Preprocessing stages information
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file or None if error
        """
        logger.info("\n" + "="*80)
        logger.info("CREATING SUMMARY DASHBOARD")
        logger.info("="*80)

        target_col = self.config.target_column

        try:
            # Create large dashboard
            fig = plt.figure(figsize=(20, 24))
            gs = fig.add_gridspec(6, 4, hspace=0.3, wspace=0.3)

            # 1. Time series of target variable
            ax1 = fig.add_subplot(gs[0, :2])
            if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
                ax1.plot(data.index, data[target_col], linewidth=1, color='blue', alpha=0.7)
                ax1.set_title(f'Time Series: {target_col}', fontsize=12, fontweight='bold')
                ax1.set_xlabel('Date', fontsize=10)
                ax1.set_ylabel(target_col, fontsize=10)
                ax1.grid(True, alpha=0.3)
                ax1.tick_params(axis='x', rotation=45)
            else:
                ax1.text(0.5, 0.5, 'No time series data available',
                         ha='center', va='center', transform=ax1.transAxes)

            # 2. Target variable distribution
            ax2 = fig.add_subplot(gs[0, 2:])
            if target_col in data.columns:
                values = data[target_col].dropna()
                if len(values) > 0:
                    ax2.hist(values, bins=30, edgecolor='black', alpha=0.7, color='green')
                    ax2.set_title(f'Distribution: {target_col}', fontsize=12, fontweight='bold')
                    ax2.set_xlabel(target_col, fontsize=10)
                    ax2.set_ylabel('Frequency', fontsize=10)
                    ax2.grid(True, alpha=0.3)
                else:
                    ax2.text(0.5, 0.5, 'No data for distribution',
                             ha='center', va='center', transform=ax2.transAxes)

            # 3. Correlation matrix (top features)
            ax3 = fig.add_subplot(gs[1, :])
            numeric_cols = data.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) > 1:
                display_cols = list(numeric_cols[:15])
                if target_col not in display_cols and target_col in data.columns:
                    display_cols = [target_col] + [c for c in display_cols if c != target_col][:14]

                corr_matrix = data[display_cols].corr()
                mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

                im = ax3.imshow(corr_matrix.where(~mask), cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
                ax3.set_title('Correlation Matrix (Top 15 Features)',
                              fontsize=12, fontweight='bold')
                ax3.set_xticks(range(len(display_cols)))
                ax3.set_yticks(range(len(display_cols)))
                ax3.set_xticklabels(display_cols, rotation=90, fontsize=8)
                ax3.set_yticklabels(display_cols, fontsize=8)
                plt.colorbar(im, ax=ax3, shrink=0.8)

            # 4. Seasonal patterns
            ax4 = fig.add_subplot(gs[2, :2])
            if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
                data_copy = data.copy()
                data_copy['month'] = data_copy.index.month

                monthly_avg = data_copy.groupby('month')[target_col].mean()
                colors = plt.cm.Set3(np.linspace(0, 1, len(monthly_avg)))
                ax4.bar(monthly_avg.index, monthly_avg.values, color=colors, edgecolor='black')
                ax4.set_title('Average Values by Month', fontsize=12, fontweight='bold')
                ax4.set_xlabel('Month', fontsize=10)
                ax4.set_ylabel(f'Average {target_col}', fontsize=10)
                ax4.set_xticks(range(1, 13))
                month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
                ax4.set_xticklabels(month_names)
                ax4.grid(True, alpha=0.3, axis='y')

            # 5. Weekly patterns
            ax5 = fig.add_subplot(gs[2, 2:])
            if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
                data_copy = data.copy()
                data_copy['dayofweek'] = data_copy.index.dayofweek

                daily_avg = data_copy.groupby('dayofweek')[target_col].mean()
                colors = plt.cm.Paired(np.linspace(0, 1, len(daily_avg)))
                ax5.bar(daily_avg.index, daily_avg.values, color=colors, edgecolor='black')
                ax5.set_title('Average Values by Day of Week', fontsize=12, fontweight='bold')
                ax5.set_xlabel('Day of Week', fontsize=10)
                ax5.set_ylabel(f'Average {target_col}', fontsize=10)
                ax5.set_xticks(range(7))
                ax5.set_xticklabels(['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'])
                ax5.grid(True, alpha=0.3, axis='y')

            # 6. Trend and seasonality
            ax6 = fig.add_subplot(gs[3, :])
            if target_col in data.columns and len(data) > 30:
                try:
                    window_size = min(365, len(data) // 10)
                    if window_size >= 7:
                        rolling_mean = data[target_col].rolling(window=window_size, center=True).mean()
                        rolling_std = data[target_col].rolling(window=window_size, center=True).std()

                        ax6.plot(data.index, data[target_col], alpha=0.5,
                                 label='Original Series', linewidth=0.5, color='blue')
                        ax6.plot(rolling_mean.index, rolling_mean,
                                 label=f'Rolling Mean ({window_size} days)',
                                 color='red', linewidth=2)
                        ax6.fill_between(rolling_mean.index,
                                         rolling_mean - rolling_std,
                                         rolling_mean + rolling_std,
                                         alpha=0.2, color='red')

                        ax6.set_title('Trend and Volatility', fontsize=12, fontweight='bold')
                        ax6.set_xlabel('Date', fontsize=10)
                        ax6.set_ylabel(target_col, fontsize=10)
                        ax6.legend(fontsize=9, loc='upper left')
                        ax6.grid(True, alpha=0.3)
                    else:
                        ax6.text(0.5, 0.5, 'Insufficient data for trend analysis',
                                 ha='center', va='center', transform=ax6.transAxes)
                except Exception as e:
                    logger.warning(f"Error plotting trend: {e}")
                    ax6.text(0.5, 0.5, 'Error plotting trend',
                             ha='center', va='center', transform=ax6.transAxes)

            # 7. Preprocessing statistics
            if preprocessing_stages:
                ax7 = fig.add_subplot(gs[4, :2])

                stages = list(preprocessing_stages.keys())
                values = list(preprocessing_stages.values())

                colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(stages)))
                bars = ax7.bar(range(len(stages)), values, color=colors, edgecolor='black')
                ax7.set_title('Preprocessing Statistics', fontsize=12, fontweight='bold')
                ax7.set_xlabel('Processing Stage', fontsize=10)
                ax7.set_ylabel('Value', fontsize=10)
                ax7.set_xticks(range(len(stages)))
                ax7.set_xticklabels([s[:15] + '...' if len(s) > 15 else s for s in stages],
                                    rotation=45, ha='right', fontsize=9)
                ax7.grid(True, alpha=0.3, axis='y')

                # Add values on bars
                for bar, value in zip(bars, values):
                    height = bar.get_height()
                    ax7.text(bar.get_x() + bar.get_width()/2., height,
                             f'{value:.2f}', ha='center', va='bottom', fontsize=8)

            # 8. Data information
            ax8 = fig.add_subplot(gs[4, 2:])
            ax8.axis('off')

            info_text = []
            info_text.append("GENERAL CHARACTERISTICS:")
            info_text.append(f"• Number of records: {len(data):,}")
            info_text.append(f"• Number of features: {len(data.columns)}")

            if isinstance(data.index, pd.DatetimeIndex):
                info_text.append(f"• Period: {data.index.min().strftime('%Y-%m-%d')} - "
                                 f"{data.index.max().strftime('%Y-%m-%d')}")
                info_text.append(f"• Days of data: {(data.index.max() - data.index.min()).days}")

            if target_col in data.columns:
                target_stats = data[target_col].describe()
                info_text.append(f"\nTARGET VARIABLE '{target_col}':")
                info_text.append(f"• Mean: {target_stats['mean']:.2f}")
                info_text.append(f"• Standard deviation: {target_stats['std']:.2f}")
                info_text.append(f"• Minimum: {target_stats['min']:.2f}")
                info_text.append(f"• 25%: {target_stats['25%']:.2f}")
                info_text.append(f"• 50% (median): {target_stats['50%']:.2f}")
                info_text.append(f"• 75%: {target_stats['75%']:.2f}")
                info_text.append(f"• Maximum: {target_stats['max']:.2f}")

            info_text.append(f"\nDATA TYPES:")
            for dtype, count in data.dtypes.value_counts().items():
                info_text.append(f"• {dtype}: {count} columns")

            missing_info = data.isnull().sum()
            missing_total = missing_info.sum()
            missing_percent = missing_total / data.size * 100
            info_text.append(f"\nMISSING VALUES:")
            info_text.append(f"• Total missing: {missing_total:,}")
            info_text.append(f"• Missing percentage: {missing_percent:.2f}%")

            if missing_total > 0:
                top_missing = missing_info.nlargest(5)
                info_text.append(f"• Top 5 columns with missing values:")
                for col, count in top_missing.items():
                    percent = count / len(data) * 100
                    info_text.append(f"  {col}: {count} ({percent:.1f}%)")

            ax8.text(0.02, 0.98, '\n'.join(info_text), transform=ax8.transAxes,
                     fontsize=8, verticalalignment='top', fontfamily='monospace')

            # 9. Autocorrelation plot
            ax9 = fig.add_subplot(gs[5, :2])
            if target_col in data.columns:
                try:
                    series = data[target_col].dropna()
                    if len(series) > 50:
                        plot_acf(series, lags=min(50, len(series)-1), ax=ax9, alpha=0.05)
                        ax9.set_title('Autocorrelation Function (ACF)', fontsize=12, fontweight='bold')
                        ax9.set_xlabel('Lag', fontsize=10)
                        ax9.set_ylabel('Autocorrelation', fontsize=10)
                        ax9.grid(True, alpha=0.3)
                    else:
                        ax9.text(0.5, 0.5, 'Insufficient data for ACF',
                                 ha='center', va='center', transform=ax9.transAxes)
                except Exception as e:
                    logger.warning(f"Error plotting ACF: {e}")
                    ax9.text(0.5, 0.5, 'Error calculating ACF',
                             ha='center', va='center', transform=ax9.transAxes)

            # 10. Partial autocorrelation plot
            ax10 = fig.add_subplot(gs[5, 2:])
            if target_col in data.columns:
                try:
                    series = data[target_col].dropna()
                    if len(series) > 50:
                        plot_pacf(series, lags=min(50, len(series)-1), ax=ax10, alpha=0.05)
                        ax10.set_title('Partial Autocorrelation Function (PACF)',
                                       fontsize=12, fontweight='bold')
                        ax10.set_xlabel('Lag', fontsize=10)
                        ax10.set_ylabel('Partial Autocorrelation', fontsize=10)
                        ax10.grid(True, alpha=0.3)
                    else:
                        ax10.text(0.5, 0.5, 'Insufficient data for PACF',
                                  ha='center', va='center', transform=ax10.transAxes)
                except Exception as e:
                    logger.warning(f"Error plotting PACF: {e}")
                    ax10.text(0.5, 0.5, 'Error calculating PACF',
                              ha='center', va='center', transform=ax10.transAxes)

            plt.suptitle('Data Analysis Summary Dashboard', fontsize=16, fontweight='bold', y=0.98)
            plt.tight_layout()

            # Save
            filepath = self._save_figure(fig, filename, "summary")
            self.plot_files['summary_dashboard'] = filepath
            return filepath

        except Exception as e:
            logger.error(f"Error creating summary dashboard: {e}")
            return None

    # ============================================
    # SPECIFIC METHODS FOR SAVING YOUR PLOTS
    # ============================================

    def save_data_split_plot(self, filename: str = "data_split.png") -> str:
        """
        Save data split plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "time_series")
            self.plot_files['data_split'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving data_split plot: {e}")
            return None

    def save_feature_selection_correlation_plot(self, filename: str = "feature_selection_correlation.png") -> str:
        """
        Save feature selection correlation plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "correlations")
            self.plot_files['feature_selection_correlation'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving feature_selection_correlation plot: {e}")
            return None

    def save_missing_values_analysis_plot(self, filename: str = "missing_values_analysis.png") -> str:
        """
        Save missing values analysis plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "preprocessing")
            self.plot_files['missing_values_analysis'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving missing_values_analysis plot: {e}")
            return None

    def save_outlier_handling_results_plot(self, filename: str = "outlier_handling_results.png") -> str:
        """
        Save outlier handling results plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "preprocessing")
            self.plot_files['outlier_handling_results'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving outlier_handling_results plot: {e}")
            return None

    def save_outliers_analysis_plot(self, filename: str = "outliers_analysis.png") -> str:
        """
        Save outliers analysis plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "preprocessing")
            self.plot_files['outliers_analysis'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving outliers_analysis plot: {e}")
            return None

    def save_scaling_results_plot(self, filename: str = "scaling_results.png") -> str:
        """
        Save scaling results plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "preprocessing")
            self.plot_files['scaling_results'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving scaling_results plot: {e}")
            return None

    def save_stationarity_analysis_plot(self, filename: str = "stationarity_analysis.png") -> str:
        """
        Save stationarity analysis plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "time_series")
            self.plot_files['stationarity_analysis'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving stationarity_analysis plot: {e}")
            return None

    def save_temporal_outliers_plot(self, filename: str = "temporal_outliers.png") -> str:
        """
        Save temporal outliers plot

        Parameters:
        -----------
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, "time_series")
            self.plot_files['temporal_outliers'] = filepath
            return filepath
        except Exception as e:
            logger.error(f"Error saving temporal_outliers plot: {e}")
            return None

    # ============================================
    # UNIVERSAL METHOD FOR SAVING ANY PLOT
    # ============================================

    def save_current_plot(self, filename: str, subdirectory: str = None) -> str:
        """
        Universal method for saving current plot

        Parameters:
        -----------
        filename : str
            Filename for saving
        subdirectory : str, optional
            Subdirectory for saving

        Returns:
        --------
        str : path to saved file
        """
        try:
            fig = plt.gcf()  # Get current figure
            filepath = self._save_figure(fig, filename, subdirectory)

            # Save plot information
            plot_key = filename.replace('.png', '').replace('.jpg', '')
            self.plot_files[plot_key] = filepath

            return filepath
        except Exception as e:
            logger.error(f"Error saving plot {filename}: {e}")
            return None

    # ============================================
    # ADDITIONAL VISUALISATION METHODS
    # ============================================

    def create_feature_importance_plot(
        self,
        feature_importance: Dict,
        top_n: int = 20,
        filename: str = "feature_importance"
    ) -> str:
        """
        Create feature importance plot

        Parameters:
        -----------
        feature_importance : Dict
            Dictionary with feature importance
        top_n : int
            Number of top features to display
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file or None if error
        """
        if not feature_importance:
            logger.warning("No feature importance data for visualisation")
            return None

        try:
            # Convert to Series and sort
            importance_series = pd.Series(feature_importance).sort_values(ascending=False)
            top_features = importance_series.head(top_n)

            # Create plot
            fig, ax = plt.subplots(figsize=(12, 8))

            y_pos = np.arange(len(top_features))
            colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(top_features)))

            bars = ax.barh(y_pos, top_features.values, color=colors, edgecolor='black')
            ax.set_yticks(y_pos)
            ax.set_yticklabels(top_features.index, fontsize=10)
            ax.invert_yaxis()
            ax.set_xlabel('Feature Importance', fontsize=11, fontweight='bold')
            ax.set_title(f'Top-{top_n} Most Important Features', fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='x')

            # Add values on bars
            for i, (bar, value) in enumerate(zip(bars, top_features.values)):
                width = bar.get_width()
                ax.text(width * 1.01, bar.get_y() + bar.get_height()/2,
                        f'{value:.4f}', va='center', fontsize=9, fontweight='bold')

            # Add additional information
            plt.text(0.02, 0.98, f'Total features: {len(importance_series)}',
                     transform=fig.transFigure, fontsize=9, verticalalignment='top')

            plt.tight_layout()

            # Save
            filepath = self._save_figure(fig, filename, "features")
            self.plot_files['feature_importance'] = filepath
            return filepath

        except Exception as e:
            logger.error(f"Error creating feature importance plot: {e}")
            return None

    def create_correlation_heatmap(
        self,
        data: pd.DataFrame,
        top_n: int = 20,
        filename: str = "correlation_heatmap"
    ) -> Tuple[str, Optional[str]]:
        """
        Create correlation heatmap

        Parameters:
        -----------
        data : pd.DataFrame
            Data for analysis
        top_n : int
            Number of top features to display
        filename : str
            Filename for saving

        Returns:
        --------
        Tuple[str, Optional[str]]:
            (path to main heatmap, path to target correlation heatmap)
        """
        target_col = self.config.target_column

        try:
            numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()

            if len(numeric_cols) < 2:
                logger.warning("Insufficient numeric features for correlation analysis")
                return None, None

            # Create two heatmaps

            # 1. Main correlation heatmap between all features
            main_filepath = self._create_main_correlation_heatmap(data, numeric_cols, top_n, filename)

            # 2. Target correlation heatmap
            target_filepath = None
            if target_col in data.columns and target_col in numeric_cols:
                target_filepath = self._create_target_correlation_heatmap(data, target_col, numeric_cols, filename)

            return main_filepath, target_filepath

        except Exception as e:
            logger.error(f"Error creating correlation heatmap: {e}")
            return None, None

    def _create_main_correlation_heatmap(
        self,
        data: pd.DataFrame,
        numeric_cols: List[str],
        top_n: int,
        filename: str
    ) -> str:
        """Create main correlation heatmap"""
        # Limit number of features for better readability
        if len(numeric_cols) > top_n:
            # Select features with highest variance
            variances = data[numeric_cols].var().sort_values(ascending=False)
            selected_cols = variances.head(top_n).index.tolist()
        else:
            selected_cols = numeric_cols

        # Calculate correlation
        corr_matrix = data[selected_cols].corr()

        fig, ax = plt.subplots(figsize=(14, 12))

        # Mask for upper triangle
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

        # Create heatmap
        sns.heatmap(
            corr_matrix,
            annot=True,
            fmt='.2f',
            cmap='coolwarm',
            center=0,
            square=True,
            mask=mask,
            cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'},
            linewidths=0.5,
            linecolor='white',
            ax=ax,
            annot_kws={'size': 8}
        )

        ax.set_title(f'Correlation Matrix Between Features (Top-{top_n})',
                     fontsize=14, fontweight='bold', pad=20)

        plt.tight_layout()

        # Save
        filepath = self._save_figure(fig, filename, "correlations")
        self.plot_files['correlation_heatmap_main'] = filepath
        return filepath

    def _create_target_correlation_heatmap(
        self,
        data: pd.DataFrame,
        target_col: str,
        numeric_cols: List[str],
        filename: str
    ) -> str:
        """Create target correlation heatmap"""
        # Calculate correlations with target variable
        correlations = data[numeric_cols].corrwith(data[target_col]).sort_values(key=abs, ascending=False)

        # Exclude target variable itself
        correlations = correlations[correlations.index != target_col]

        # Take top 15 features
        top_features = correlations.head(15)

        fig, ax = plt.subplots(figsize=(10, 8))

        colors = ['red' if x < 0 else 'green' for x in top_features.values]
        bars = ax.barh(range(len(top_features)), top_features.values, color=colors, edgecolor='black')

        ax.set_yticks(range(len(top_features)))
        ax.set_yticklabels(top_features.index, fontsize=10)
        ax.invert_yaxis()
        ax.set_xlabel('Correlation Coefficient', fontsize=11, fontweight='bold')
        ax.set_title(f'Feature Correlations with Target Variable "{target_col}"',
                     fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, axis='x')
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)

        # Add values on bars
        for bar, value in zip(bars, top_features.values):
            width = bar.get_width()
            ax.text(width + (0.01 if width >= 0 else -0.04),
                    bar.get_y() + bar.get_height()/2,
                    f'{value:.3f}',
                    va='center',
                    ha='left' if width >= 0 else 'right',
                    fontsize=9,
                    fontweight='bold',
                    color='black')

        plt.tight_layout()

        # Save
        target_filename = f"{filename}_with_target"
        filepath = self._save_figure(fig, target_filename, "correlations")
        self.plot_files['correlation_with_target'] = filepath
        return filepath

    def create_distribution_comparison(
        self,
        original_data: pd.DataFrame,
        processed_data: pd.DataFrame,
        columns: List[str] = None,
        max_columns: int = 12,
        filename: str = "distribution_comparison"
    ) -> str:
        """
        Compare distributions before and after processing

        Parameters:
        -----------
        original_data : pd.DataFrame
            Original data
        processed_data : pd.DataFrame
            Processed data
        columns : List[str], optional
            List of columns to compare
        max_columns : int
            Maximum number of columns to display
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file or None if error
        """
        try:
            if columns is None:
                # Select numeric columns common to both datasets
                numeric_cols_original = original_data.select_dtypes(include=[np.number]).columns
                numeric_cols_processed = processed_data.select_dtypes(include=[np.number]).columns
                common_cols = list(set(numeric_cols_original) & set(numeric_cols_processed))

                # Sort by variance in original data
                variances = original_data[common_cols].var().sort_values(ascending=False)
                columns = variances.head(max_columns).index.tolist()

            n_cols = min(4, len(columns))
            n_rows = (len(columns) + n_cols - 1) // n_cols

            fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.5))
            fig.suptitle('Distribution Comparison Before and After Processing',
                         fontsize=16, fontweight='bold', y=0.98)

            if n_rows == 1 and n_cols == 1:
                axes = np.array([axes])
            axes = axes.flat if hasattr(axes, 'flat') else [axes]

            for idx, col in enumerate(columns):
                if idx >= len(axes):
                    break

                ax = axes[idx]

                if col in original_data.columns and col in processed_data.columns:
                    original_values = original_data[col].dropna()
                    processed_values = processed_data[col].dropna()

                    if len(original_values) > 0 and len(processed_values) > 0:
                        # Use common bins for comparison
                        all_values = pd.concat([original_values, processed_values])
                        bins = np.histogram_bin_edges(all_values, bins=30)

                        # Histograms
                        ax.hist(original_values, bins=bins, alpha=0.5,
                                label='Before Processing', density=True, color='blue')
                        ax.hist(processed_values, bins=bins, alpha=0.5,
                                label='After Processing', density=True, color='orange')

                        # Add KDE
                        try:
                            if len(original_values) > 10:
                                kde_original = gaussian_kde(original_values)
                                x_range = np.linspace(original_values.min(), original_values.max(), 100)
                                ax.plot(x_range, kde_original(x_range), 'b-', linewidth=1.5, alpha=0.8)

                            if len(processed_values) > 10:
                                kde_processed = gaussian_kde(processed_values)
                                x_range = np.linspace(processed_values.min(), processed_values.max(), 100)
                                ax.plot(x_range, kde_processed(x_range), 'orange', linewidth=1.5, alpha=0.8)
                        except Exception:
                            # KDE can fail on degenerate data (e.g. constant columns); skip it silently
                            pass

                        # Add statistics
                        stats_text = []
                        if len(original_values) > 0:
                            stats_text.append(f"Before: μ={original_values.mean():.2f}, σ={original_values.std():.2f}")
                        if len(processed_values) > 0:
                            stats_text.append(f"After: μ={processed_values.mean():.2f}, σ={processed_values.std():.2f}")

                        if stats_text:
                            ax.text(0.02, 0.98, '\n'.join(stats_text),
                                    transform=ax.transAxes, fontsize=8,
                                    verticalalignment='top',
                                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

                        ax.set_title(f'{col}', fontsize=11, fontweight='bold')
                        ax.set_xlabel('Value', fontsize=9)
                        ax.set_ylabel('Density', fontsize=9)
                        ax.legend(fontsize=8)
                        ax.grid(True, alpha=0.3)
                    else:
                        ax.text(0.5, 0.5, 'No data',
                                ha='center', va='center', transform=ax.transAxes)
                else:
                    ax.text(0.5, 0.5, 'Column not found',
                            ha='center', va='center', transform=ax.transAxes)

            # Hide unused subplots
            for idx in range(len(columns), len(axes)):
                axes[idx].set_visible(False)

            plt.tight_layout()

            # Save
            filepath = self._save_figure(fig, filename, "distributions")
            self.plot_files['distribution_comparison'] = filepath
            return filepath

        except Exception as e:
            logger.error(f"Error creating distribution comparison: {e}")
            return None

    def create_time_series_decomposition_plot(
        self,
        decomposition_result: Dict,
        filename: str = "time_series_decomposition"
    ) -> str:
        """
        Visualise time series decomposition

        Parameters:
        -----------
        decomposition_result : Dict
            Decomposition results
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file or None if error
        """
        target_col = self.config.target_column

        try:
            fig, axes = plt.subplots(4, 1, figsize=(14, 10))
            fig.suptitle(f'Time Series Decomposition: {target_col}',
                         fontsize=16, fontweight='bold', y=0.98)

            # Original series
            if 'observed' in decomposition_result:
                observed = decomposition_result['observed']
                axes[0].plot(observed, color='blue', linewidth=1.5)
                axes[0].set_ylabel('Observed', fontsize=11, fontweight='bold')
                axes[0].grid(True, alpha=0.3)
                axes[0].set_title('Original Time Series', fontsize=12)

            # Trend
            if 'trend' in decomposition_result and decomposition_result['trend'] is not None:
                trend = decomposition_result['trend']
                axes[1].plot(trend, color='red', linewidth=2)
                axes[1].set_ylabel('Trend', fontsize=11, fontweight='bold')
                axes[1].grid(True, alpha=0.3)
                axes[1].set_title('Trend Component', fontsize=12)

            # Seasonality
            if 'seasonal' in decomposition_result and decomposition_result['seasonal'] is not None:
                seasonal = decomposition_result['seasonal']
                axes[2].plot(seasonal, color='green', linewidth=1.5)
                axes[2].set_ylabel('Seasonal', fontsize=11, fontweight='bold')
                axes[2].grid(True, alpha=0.3)
                axes[2].set_title('Seasonal Component', fontsize=12)

            # Residuals
            if 'residual' in decomposition_result and decomposition_result['residual'] is not None:
                residual = decomposition_result['residual']
                axes[3].plot(residual, color='purple', linewidth=1, alpha=0.7)
                axes[3].set_ylabel('Residuals', fontsize=11, fontweight='bold')
                axes[3].set_xlabel('Date', fontsize=11, fontweight='bold')
                axes[3].grid(True, alpha=0.3)
                axes[3].set_title('Residual Component', fontsize=12)

                # Add residual statistics
                if len(residual) > 0:
                    stats_text = (f"Mean: {residual.mean():.4f}\n"
                                  f"Std: {residual.std():.4f}\n"
                                  f"Min: {residual.min():.4f}\n"
                                  f"Max: {residual.max():.4f}")
                    axes[3].text(0.02, 0.98, stats_text, transform=axes[3].transAxes,
                                 fontsize=8, verticalalignment='top',
                                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            plt.tight_layout()

            # Save
            filepath = self._save_figure(fig, filename, "time_series")
            self.plot_files['time_series_decomposition'] = filepath
            return filepath

        except Exception as e:
            logger.error(f"Error creating time series decomposition: {e}")
            return None

    def create_data_quality_report(
        self,
        validation_results: Dict,
        filename: str = "data_quality_report"
    ) -> str:
        """
        Create visual data quality report

        Parameters:
        -----------
        validation_results : Dict
            Validation results
        filename : str
            Filename for saving

        Returns:
        --------
        str : path to saved file or None if error
        """
        try:
            fig = plt.figure(figsize=(16, 12))
            fig.suptitle('Data Quality Report', fontsize=18, fontweight='bold', y=0.98)

            # Use GridSpec for more complex layout
            gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

            # 1. Quality radar chart (top left)
            ax1 = fig.add_subplot(gs[0, 0], projection='polar')

            categories = ['Size', 'Missing', 'Duplicates', 'Stability', 'Informativeness']

            # Extract values from validation results
            if 'quality_metrics' in validation_results:
                values = [
                    validation_results['quality_metrics'].get('size_score', 0.5),
                    validation_results['quality_metrics'].get('missing_score', 0.5),
                    validation_results['quality_metrics'].get('duplicates_score', 0.5),
                    validation_results['quality_metrics'].get('stability_score', 0.5),
                    validation_results['quality_metrics'].get('informativeness_score', 0.5)
                ]
            else:
                values = [0.8, 0.7, 0.9, 0.6, 0.8]

            N = len(categories)
            angles = [n / float(N) * 2 * np.pi for n in range(N)]
            angles += angles[:1]
            values += values[:1]

            ax1.plot(angles, values, 'o-', linewidth=2, color='blue')
            ax1.fill(angles, values, alpha=0.25, color='blue')
            ax1.set_xticks(angles[:-1])
            ax1.set_xticklabels(categories, fontsize=10)
            ax1.set_ylim(0, 1)
            ax1.set_title('Data Quality Radar Chart', fontsize=12, fontweight='bold')
            ax1.grid(True)

            # 2. Check status (top centre)
            ax2 = fig.add_subplot(gs[0, 1])

            basic_checks = validation_results.get('basic_checks', {})
            checks_passed = sum(1 for check in basic_checks.values() if check.get('passed', False))
            checks_total = len(basic_checks)
            checks_failed = checks_total - checks_passed

            if checks_total > 0:
                colors = ['#4CAF50' if checks_passed > 0 else '#FF6B6B',
                          '#FF6B6B' if checks_failed > 0 else '#4CAF50']
                bars = ax2.bar(['Passed', 'Failed'],
                               [checks_passed, checks_failed],
                               color=colors, edgecolor='black')

                ax2.set_title(f'Basic Checks: {checks_passed}/{checks_total}',
                              fontsize=12, fontweight='bold')
                ax2.set_ylabel('Number of Checks', fontsize=10)
                ax2.grid(True, alpha=0.3, axis='y')

                # Add values on bars
                for bar, value in zip(bars, [checks_passed, checks_failed]):
                    height = bar.get_height()
                    ax2.text(bar.get_x() + bar.get_width()/2., height,
                             f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
            else:
                ax2.text(0.5, 0.5, 'No check data available',
                         ha='center', va='center', transform=ax2.transAxes)
                ax2.set_title('Basic Checks', fontsize=12, fontweight='bold')

            # 3. Overall score (top right)
            ax3 = fig.add_subplot(gs[0, 2])

            overall_score = validation_results.get('overall_score', 0)
            status = validation_results.get('status', 'UNKNOWN')

            # Score pie chart
            sizes = [overall_score, 100 - overall_score]

            if overall_score >= 80:
                colors = ['#4CAF50', '#E0E0E0']  # Green
            elif overall_score >= 60:
                colors = ['#FFC107', '#E0E0E0']  # Yellow
            else:
                colors = ['#F44336', '#E0E0E0']  # Red

            wedges, texts, autotexts = ax3.pie(sizes, colors=colors, startangle=90,
                                               autopct='%1.1f%%', pctdistance=0.85)

            # Central text
            status_colors = {'PASS': '#4CAF50', 'WARNING': '#FFC107', 'FAIL': '#F44336'}
            status_color = status_colors.get(status, '#757575')

            ax3.text(0, 0, f'{overall_score}/100\n{status}',
                     ha='center', va='center', fontsize=14, fontweight='bold',
                     color=status_color)
            ax3.set_title('Overall Quality Score', fontsize=12, fontweight='bold')

            # 4. Issue distribution by type (left middle)
            ax4 = fig.add_subplot(gs[1, 0])

            issues = validation_results.get('issues', {})
            issue_counts = {
                'Critical': len(issues.get('critical', [])),
                'Warnings': len(issues.get('warning', [])),
                'Informational': len(issues.get('info', []))
            }

            if any(issue_counts.values()):
                colors = ['#F44336', '#FF9800', '#2196F3']
                bars = ax4.bar(issue_counts.keys(), issue_counts.values(),
                               color=colors, edgecolor='black')

                ax4.set_title('Data Issues by Type', fontsize=12, fontweight='bold')
                ax4.set_ylabel('Number of Issues', fontsize=10)
                ax4.tick_params(axis='x', rotation=45)
                ax4.grid(True, alpha=0.3, axis='y')

                # Add values on bars
                for bar, value in zip(bars, issue_counts.values()):
                    height = bar.get_height()
                    ax4.text(bar.get_x() + bar.get_width()/2., height,
                             f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
            else:
                ax4.text(0.5, 0.5, 'No issues detected',
                         ha='center', va='center', transform=ax4.transAxes, fontsize=12)
                ax4.set_title('Data Issues', fontsize=12, fontweight='bold')

            # 5. Detailed information (remaining cells)
            ax5 = fig.add_subplot(gs[1:, 1:])
            ax5.axis('off')

            # Form text report
            report_text = []
            report_text.append("DETAILED REPORT:")
            report_text.append("=" * 40)

            # Basic information
            report_text.append("\nBASIC INFORMATION:")
            report_text.append(f"• Overall score: {overall_score}/100")
            report_text.append(f"• Status: {status}")
            report_text.append(f"• Checks passed: {checks_passed}/{checks_total}")

            # Check details
            if basic_checks:
                report_text.append("\nCHECK DETAILS:")
                for check_name, check_result in basic_checks.items():
                    status_icon = "✓" if check_result.get('passed', False) else "✗"
                    report_text.append(f"• {status_icon} {check_name}: {check_result.get('message', '')}")

            # Issues
            if any(issue_counts.values()):
                report_text.append("\nDETECTED ISSUES:")

                if issue_counts['Critical'] > 0:
                    report_text.append("\nCRITICAL:")
                    for issue in issues.get('critical', []):
                        report_text.append(f"  • {issue}")

                if issue_counts['Warnings'] > 0:
                    report_text.append("\nWARNINGS:")
                    for issue in issues.get('warning', []):
                        report_text.append(f"  • {issue}")

                if issue_counts['Informational'] > 0:
                    report_text.append("\nINFORMATIONAL:")
                    for issue in issues.get('info', []):
                        report_text.append(f"  • {issue}")

            # Recommendations
            recommendations = validation_results.get('recommendations', [])
            if recommendations:
                report_text.append("\nRECOMMENDATIONS:")
                for i, rec in enumerate(recommendations, 1):
                    report_text.append(f"{i}. {rec}")

            ax5.text(0.02, 0.98, '\n'.join(report_text), transform=ax5.transAxes,
                     fontsize=9, verticalalignment='top', fontfamily='monospace',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.1))

            plt.tight_layout()

            # Save
            filepath = self._save_figure(fig, filename, "reports")
            self.plot_files['data_quality_report'] = filepath
            return filepath

        except Exception as e:
            logger.error(f"Error creating data quality report: {e}")
            return None

    # ============================================
    # METHODS FOR BATCH SAVING
    # ============================================

    def save_all_preprocessing_plots(self) -> Dict[str, str]:
        """
        Save all preprocessing plots from current session

        Returns:
        --------
        Dict[str, str] : dictionary with paths to saved plots
        """
        logger.info("Saving all preprocessing plots...")

        plots_saved = {}

        # Get all open figures
        figure_numbers = plt.get_fignums()

        if not figure_numbers:
            logger.warning("No open plots to save")
            return plots_saved

        # Save each plot
        for fig_num in figure_numbers:
            fig = plt.figure(fig_num)
            filename = f"preprocessing_plot_{fig_num}.png"
            filepath = self._save_figure(fig, filename, "preprocessing")
            if filepath:
                plots_saved[f"plot_{fig_num}"] = filepath

        logger.info(f"Saved {len(plots_saved)} preprocessing plots")
        return plots_saved

    def create_all_visualizations(
        self,
        data: pd.DataFrame,
        processed_data: pd.DataFrame = None,
        feature_importance: Dict = None,
        decomposition_result: Dict = None,
        validation_results: Dict = None,
        preprocessing_stages: Dict = None
    ) -> Dict[str, str]:
        """
        Create all visualisations in one call

        Parameters:
        -----------
        data : pd.DataFrame
            Original data
        processed_data : pd.DataFrame, optional
            Processed data
        feature_importance : Dict, optional
            Feature importance
        decomposition_result : Dict, optional
            Decomposition results
        validation_results : Dict, optional
            Validation results
        preprocessing_stages : Dict, optional
            Preprocessing stages

        Returns:
        --------
        Dict[str, str] : dictionary with paths to created plots
        """
        logger.info("\n" + "="*80)
        logger.info("STARTING ALL VISUALISATIONS CREATION")
        logger.info("="*80)

        result_files = {}

        # 1. Summary dashboard
        if data is not None:
            logger.info("Creating summary dashboard...")
            summary_path = self.create_summary_dashboard(data, preprocessing_stages)
            if summary_path:
                result_files['summary'] = summary_path

        # 2. Correlation heatmaps
        if data is not None:
            logger.info("Creating correlation heatmaps...")
            main_corr, target_corr = self.create_correlation_heatmap(data)
            if main_corr:
                result_files['correlation_main'] = main_corr
            if target_corr:
                result_files['correlation_target'] = target_corr

        # 3. Distribution comparison
        if data is not None and processed_data is not None:
            logger.info("Creating distribution comparison...")
            dist_path = self.create_distribution_comparison(data, processed_data)
            if dist_path:
                result_files['distribution'] = dist_path

        # 4. Feature importance
        if feature_importance:
            logger.info("Creating feature importance plot...")
            feat_path = self.create_feature_importance_plot(feature_importance)
            if feat_path:
                result_files['feature_importance'] = feat_path

        # 5. Time series decomposition
        if decomposition_result:
            logger.info("Creating time series decomposition...")
            decomp_path = self.create_time_series_decomposition_plot(decomposition_result)
            if decomp_path:
                result_files['decomposition'] = decomp_path

        # 6. Data quality report
        if validation_results:
            logger.info("Creating data quality report...")
            quality_path = self.create_data_quality_report(validation_results)
            if quality_path:
                result_files['quality_report'] = quality_path

        # Save information about all plots
        self.save_plots_info()

        logger.info("\n" + "="*80)
        logger.info("VISUALISATIONS SUCCESSFULLY CREATED")
        logger.info("="*80)

        for plot_name, plot_path in result_files.items():
            if plot_path:
                logger.info(f"✓ {plot_name}: {plot_path}")

        return result_files

    def get_all_plots(self) -> Dict:
        """Get information about all created plots"""
        return self.plot_files

    def save_plots_info(self, filename: str = "plots_info.json") -> None:
        """Save plot information to JSON file"""
        try:
            plots_info = {
                'total_plots': len(self.plot_files),
                'plots': self.plot_files,
                'directories': {
                    'correlations': self.correlations_dir,
                    'distributions': self.distributions_dir,
                    'features': self.features_dir,
                    'time_series': self.time_series_dir,
                    'preprocessing': self.preprocessing_dir,
                    'summary': self.summary_dir,
                    'reports': self.reports_dir
                },
                'generation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'config': {
                    'target_column': self.config.target_column,
                    'results_dir': self.config.results_dir
                }
            }

            filepath = os.path.join(self.reports_dir, filename)

            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(plots_info, f, indent=4, ensure_ascii=False, default=str)

            logger.info(f"✓ Plot information saved: {filepath}")

        except Exception as e:
            logger.error(f"✗ Error saving plot information: {e}")

    def move_existing_plots(self, source_dir: str = None) -> Dict[str, str]:
        """
        Move existing plots from specified directory to structured folders

        Parameters:
        -----------
        source_dir : str, optional
            Directory with existing plots

        Returns:
        --------
        Dict[str, str] : dictionary with information about moved files
        """
        if source_dir is None:
            source_dir = self.plots_dir

        if not os.path.exists(source_dir):
            logger.warning(f"Source directory doesn't exist: {source_dir}")
            return {}

        # File to folder mapping
        file_to_folder_map = {
            # Time series
            'data_split.png': 'time_series',
            'stationarity_raskhodvoda.png': 'time_series',
            'stationarity_analysis.png': 'time_series',
            'temporal_outliers.png': 'time_series',

            # Correlations
            'feature_selection_correlation.png': 'correlations',

            # Preprocessing
            'missing_values_analysis.png': 'preprocessing',
            'outlier_handling_results.png': 'preprocessing',
            'outliers_analysis.png': 'preprocessing',
            'scaling_results.png': 'preprocessing',

            # Default
            'default': 'summary'
        }

        moved_files = {}

        for filename in os.listdir(source_dir):
            if filename.endswith('.png'):
                source_path = os.path.join(source_dir, filename)

                # Determine destination folder
                target_folder = file_to_folder_map.get(filename, file_to_folder_map['default'])
                target_dir = os.path.join(self.plots_dir, target_folder)

                # Create destination folder if doesn't exist
                os.makedirs(target_dir, exist_ok=True)

                # Target path
                target_path = os.path.join(target_dir, filename)

                try:
                    # Move file
                    os.rename(source_path, target_path)
                    moved_files[filename] = target_path
                    logger.info(f"Moved: {filename} -> {target_folder}/")
                except Exception as e:
                    logger.error(f"Error moving {filename}: {e}")

        logger.info(f"Moved {len(moved_files)} files")
        return moved_files