first commit
Browse files- .gitattributes +5 -0
- .gitignore +47 -0
- FPA_FOD_20170508.sqlite +3 -0
- README.md +156 -3
- config/__init__.py +1 -0
- config/config.py +181 -0
- data/processed/fires_features.parquet +3 -0
- data/processed/fires_processed.parquet +3 -0
- data/processed/fires_raw.parquet +3 -0
- data/processed/test.parquet +3 -0
- data/processed/train.parquet +3 -0
- models/best_params.json +11 -0
- models/model_metadata.joblib +3 -0
- models/wildfire_model.txt +3 -0
- reports/figures/cause_by_size.png +0 -0
- reports/figures/class_distribution.png +0 -0
- reports/figures/classification_metrics.png +0 -0
- reports/figures/confusion_matrix.png +0 -0
- reports/figures/feature_importance.csv +29 -0
- reports/figures/geographic_distribution.png +3 -0
- reports/figures/missing_values.png +0 -0
- reports/figures/prediction_distribution.png +0 -0
- reports/figures/shap_importance.png +0 -0
- reports/figures/shap_importance_summary.png +3 -0
- reports/figures/temporal_patterns.png +3 -0
- requirements.txt +29 -0
- run_pipeline.py +91 -0
- scripts/01_extract_data.py +163 -0
- scripts/02_eda.py +345 -0
- scripts/03_preprocess.py +280 -0
- scripts/04_feature_engineering.py +290 -0
- scripts/05_train_model.py +370 -0
- scripts/06_evaluate.py +438 -0
- scripts/07_predict.py +312 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
FPA_FOD_20170508.sqlite filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
models/wildfire_model.txt filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
reports/figures/geographic_distribution.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
reports/figures/shap_importance_summary.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
reports/figures/temporal_patterns.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Virtual environments
|
| 7 |
+
venv/
|
| 8 |
+
env/
|
| 9 |
+
.venv/
|
| 10 |
+
.env/
|
| 11 |
+
|
| 12 |
+
# IDE settings
|
| 13 |
+
.vscode/
|
| 14 |
+
.idea/
|
| 15 |
+
*.swp
|
| 16 |
+
*.swo
|
| 17 |
+
|
| 18 |
+
# Jupyter Notebooks
|
| 19 |
+
.ipynb_checkpoints/
|
| 20 |
+
|
| 21 |
+
# Data files
|
| 22 |
+
*.sqlite
|
| 23 |
+
*.db
|
| 24 |
+
*.parquet
|
| 25 |
+
data/processed/
|
| 26 |
+
data/raw/
|
| 27 |
+
|
| 28 |
+
# Model artifacts
|
| 29 |
+
models/*.txt
|
| 30 |
+
models/*.joblib
|
| 31 |
+
models/*.json
|
| 32 |
+
models/*.pkl
|
| 33 |
+
|
| 34 |
+
# Reports and figures
|
| 35 |
+
reports/figures/*.png
|
| 36 |
+
reports/figures/*.csv
|
| 37 |
+
|
| 38 |
+
# OS files
|
| 39 |
+
.DS_Store
|
| 40 |
+
Thumbs.db
|
| 41 |
+
|
| 42 |
+
# Logs
|
| 43 |
+
*.log
|
| 44 |
+
|
| 45 |
+
# Environment variables
|
| 46 |
+
.env
|
| 47 |
+
.env.local
|
FPA_FOD_20170508.sqlite
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f04b23a24989770ce05fa354662b03e597ad164ddf5b7932b8a53b46d0ed428b
|
| 3 |
+
size 795785216
|
README.md
CHANGED
|
@@ -1,3 +1,156 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Wildfire Size Classification Project
|
| 2 |
+
|
| 3 |
+
Predicting wildfire size classes using machine learning on the FPA FOD (Fire Program Analysis Fire-Occurrence Database) containing 1.88 million US wildfire records from 1992-2015.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
This project builds an **ordinal classification model** to predict fire size categories:
|
| 8 |
+
- **Small** (0-9.9 acres): Original classes A + B
|
| 9 |
+
- **Medium** (10-299 acres): Original classes C + D
|
| 10 |
+
- **Large** (300+ acres): Original classes E + F + G
|
| 11 |
+
|
| 12 |
+
### Key Features
|
| 13 |
+
- **Ordinal-aware classification**: Leverages the natural ordering of fire size classes
|
| 14 |
+
- **Geospatial features**: Coordinate clustering, regional binning, distance metrics
|
| 15 |
+
- **Temporal features**: Cyclical encoding of month/day, fire season indicators
|
| 16 |
+
- **Class imbalance handling**: Balanced class weights for rare large fire events
|
| 17 |
+
- **Interpretable results**: SHAP feature importance analysis
|
| 18 |
+
|
| 19 |
+
## Project Structure
|
| 20 |
+
|
| 21 |
+
```
|
| 22 |
+
wildfires/
|
| 23 |
+
├── config/
|
| 24 |
+
│ ├── __init__.py # Package init
|
| 25 |
+
│ └── config.py # Configuration settings
|
| 26 |
+
├── data/
|
| 27 |
+
│ └── processed/ # Processed parquet files (train/test splits)
|
| 28 |
+
├── models/ # Saved model artifacts
|
| 29 |
+
│ ├── best_params.json # Tuned hyperparameters
|
| 30 |
+
│ ├── model_metadata.joblib # Feature names and metrics
|
| 31 |
+
│ └── wildfire_model.txt # Trained LightGBM model
|
| 32 |
+
├── reports/
|
| 33 |
+
│ └── figures/ # Visualizations and metrics
|
| 34 |
+
├── scripts/
|
| 35 |
+
│ ├── 01_extract_data.py # Extract SQLite → Parquet
|
| 36 |
+
│ ├── 02_eda.py # Exploratory data analysis
|
| 37 |
+
│ ├── 03_preprocess.py # Data preprocessing
|
| 38 |
+
│ ├── 04_feature_engineering.py # Feature creation
|
| 39 |
+
│ ├── 05_train_model.py # Model training
|
| 40 |
+
│ ├── 06_evaluate.py # Model evaluation
|
| 41 |
+
│ └── 07_predict.py # Prediction pipeline
|
| 42 |
+
├── run_pipeline.py # Run full or partial pipeline
|
| 43 |
+
├── requirements.txt # Dependencies
|
| 44 |
+
├── .gitignore # Git ignore rules
|
| 45 |
+
└── README.md
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Getting Started
|
| 49 |
+
|
| 50 |
+
### Prerequisites
|
| 51 |
+
- Python 3.9+
|
| 52 |
+
- SQLite database file (`FPA_FOD_20170508.sqlite`)
|
| 53 |
+
|
| 54 |
+
### Installation
|
| 55 |
+
|
| 56 |
+
1. Clone/download the repository
|
| 57 |
+
2. Create a virtual environment:
|
| 58 |
+
```bash
|
| 59 |
+
python -m venv venv
|
| 60 |
+
venv\Scripts\activate # Windows
|
| 61 |
+
# source venv/bin/activate # Linux/Mac
|
| 62 |
+
```
|
| 63 |
+
3. Install dependencies:
|
| 64 |
+
```bash
|
| 65 |
+
pip install -r requirements.txt
|
| 66 |
+
```
|
| 67 |
+
4. Place the SQLite database file in the project root
|
| 68 |
+
|
| 69 |
+
### Running the Pipeline
|
| 70 |
+
|
| 71 |
+
**Using the pipeline runner (recommended):**
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
# Run full pipeline
|
| 75 |
+
python run_pipeline.py
|
| 76 |
+
|
| 77 |
+
# Skip EDA step
|
| 78 |
+
python run_pipeline.py --skip-eda
|
| 79 |
+
|
| 80 |
+
# Run with hyperparameter tuning
|
| 81 |
+
python run_pipeline.py --tune
|
| 82 |
+
|
| 83 |
+
# Resume from a specific step (1-7)
|
| 84 |
+
python run_pipeline.py --from-step 5
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
**Or execute scripts individually:**
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
# 1. Extract data from SQLite
|
| 91 |
+
python scripts/01_extract_data.py
|
| 92 |
+
|
| 93 |
+
# 2. Exploratory data analysis (generates plots)
|
| 94 |
+
python scripts/02_eda.py
|
| 95 |
+
|
| 96 |
+
# 3. Preprocess data
|
| 97 |
+
python scripts/03_preprocess.py
|
| 98 |
+
|
| 99 |
+
# 4. Feature engineering
|
| 100 |
+
python scripts/04_feature_engineering.py
|
| 101 |
+
|
| 102 |
+
# 5. Train model (add --tune for hyperparameter tuning)
|
| 103 |
+
python scripts/05_train_model.py
|
| 104 |
+
# python scripts/05_train_model.py --tune # With Optuna tuning
|
| 105 |
+
|
| 106 |
+
# 6. Evaluate model
|
| 107 |
+
python scripts/06_evaluate.py
|
| 108 |
+
|
| 109 |
+
# 7. Make predictions
|
| 110 |
+
python scripts/07_predict.py --lat 34.05 --lon -118.24 --state CA --cause "Lightning"
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## Model Details
|
| 114 |
+
|
| 115 |
+
### Features Used
|
| 116 |
+
- **Temporal**: Month, day of week, season, fire season indicator (cyclically encoded)
|
| 117 |
+
- **Geospatial**: Lat/lon coordinates, regional clusters (K-means), coordinate bins
|
| 118 |
+
- **Categorical**: State, fire cause, reporting agency, land owner
|
| 119 |
+
- **Year**: Fire year, years since 1992
|
| 120 |
+
|
| 121 |
+
### Algorithm
|
| 122 |
+
- **LightGBM** gradient boosting for multi-class classification
|
| 123 |
+
- Class weights to handle imbalanced data (~90% small fires)
|
| 124 |
+
- Linear weighted Cohen's Kappa for ordinal evaluation
|
| 125 |
+
|
| 126 |
+
### Expected Performance
|
| 127 |
+
- Balanced Accuracy: ~65-75%
|
| 128 |
+
- Macro F1 Score: ~0.45-0.55
|
| 129 |
+
- Large fire detection is challenging due to class imbalance
|
| 130 |
+
|
| 131 |
+
## Evaluation Metrics
|
| 132 |
+
|
| 133 |
+
For ordinal classification, we prioritize:
|
| 134 |
+
- **Macro F1**: Equal importance to all classes
|
| 135 |
+
- **Balanced Accuracy**: Accounts for class imbalance
|
| 136 |
+
- **Linear Weighted Kappa**: Penalizes predictions far from true class
|
| 137 |
+
|
| 138 |
+
## Output Files
|
| 139 |
+
|
| 140 |
+
After running the pipeline:
|
| 141 |
+
- `data/processed/`: Parquet files for train/test splits
|
| 142 |
+
- `models/wildfire_model.txt`: Trained LightGBM model
|
| 143 |
+
- `models/model_metadata.joblib`: Feature names and metrics
|
| 144 |
+
- `reports/figures/`: Visualizations (confusion matrix, SHAP plots, etc.)
|
| 145 |
+
|
| 146 |
+
## Data Source
|
| 147 |
+
|
| 148 |
+
**Fire Program Analysis Fire-Occurrence Database (FPA FOD)**
|
| 149 |
+
- 1.88 million geo-referenced wildfire records
|
| 150 |
+
- Period: 1992-2015
|
| 151 |
+
- 140 million acres burned
|
| 152 |
+
- Source: US federal, state, and local fire organizations
|
| 153 |
+
|
| 154 |
+
## License
|
| 155 |
+
|
| 156 |
+
This project uses publicly available government data.
|
config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .config import *
|
config/config.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the Wildfire Size Classification project.
|
| 3 |
+
|
| 4 |
+
Target: Ordinal classification of fire size into 3 classes:
|
| 5 |
+
- 0 (Small): Classes A + B (0 - 9.9 acres)
|
| 6 |
+
- 1 (Medium): Classes C + D (10 - 299 acres)
|
| 7 |
+
- 2 (Large): Classes E + F + G (300+ acres)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# =============================================================================
|
| 13 |
+
# PATHS
|
| 14 |
+
# =============================================================================
|
| 15 |
+
|
| 16 |
+
# Project root
|
| 17 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 18 |
+
|
| 19 |
+
# Data paths
|
| 20 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 21 |
+
RAW_DATA_DIR = DATA_DIR / "raw"
|
| 22 |
+
PROCESSED_DATA_DIR = DATA_DIR / "processed"
|
| 23 |
+
|
| 24 |
+
# SQLite database path (adjust filename if different)
|
| 25 |
+
SQLITE_DB_PATH = PROJECT_ROOT / "FPA_FOD_20170508.sqlite"
|
| 26 |
+
|
| 27 |
+
# Output paths
|
| 28 |
+
MODELS_DIR = PROJECT_ROOT / "models"
|
| 29 |
+
REPORTS_DIR = PROJECT_ROOT / "reports"
|
| 30 |
+
FIGURES_DIR = REPORTS_DIR / "figures"
|
| 31 |
+
|
| 32 |
+
# Processed data files
|
| 33 |
+
RAW_PARQUET = PROCESSED_DATA_DIR / "fires_raw.parquet"
|
| 34 |
+
PROCESSED_PARQUET = PROCESSED_DATA_DIR / "fires_processed.parquet"
|
| 35 |
+
FEATURES_PARQUET = PROCESSED_DATA_DIR / "fires_features.parquet"
|
| 36 |
+
TRAIN_PARQUET = PROCESSED_DATA_DIR / "train.parquet"
|
| 37 |
+
TEST_PARQUET = PROCESSED_DATA_DIR / "test.parquet"
|
| 38 |
+
|
| 39 |
+
# =============================================================================
|
| 40 |
+
# TARGET VARIABLE CONFIGURATION
|
| 41 |
+
# =============================================================================
|
| 42 |
+
|
| 43 |
+
# Original fire size classes mapping to ordinal target
|
| 44 |
+
# A (0-0.25 acres), B (0.26-9.9 acres) -> 0 (Small)
|
| 45 |
+
# C (10-99.9 acres), D (100-299 acres) -> 1 (Medium)
|
| 46 |
+
# E (300-999 acres), F (1000-4999 acres), G (5000+ acres) -> 2 (Large)
|
| 47 |
+
|
| 48 |
+
FIRE_SIZE_CLASS_MAPPING = {
|
| 49 |
+
'A': 0, 'B': 0, # Small
|
| 50 |
+
'C': 1, 'D': 1, # Medium
|
| 51 |
+
'E': 2, 'F': 2, 'G': 2 # Large
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
TARGET_CLASS_NAMES = ['Small', 'Medium', 'Large']
|
| 55 |
+
TARGET_COLUMN = 'fire_size_ordinal'
|
| 56 |
+
ORIGINAL_TARGET_COLUMN = 'FIRE_SIZE_CLASS'
|
| 57 |
+
|
| 58 |
+
# =============================================================================
|
| 59 |
+
# FEATURE CONFIGURATION
|
| 60 |
+
# =============================================================================
|
| 61 |
+
|
| 62 |
+
# Columns to drop (IDs, redundant info, text fields)
|
| 63 |
+
COLUMNS_TO_DROP = [
|
| 64 |
+
'FOD_ID', 'FPA_ID', 'SOURCE_SYSTEM_TYPE', 'SOURCE_SYSTEM',
|
| 65 |
+
'NWCG_REPORTING_UNIT_ID', 'NWCG_REPORTING_UNIT_NAME',
|
| 66 |
+
'SOURCE_REPORTING_UNIT', 'SOURCE_REPORTING_UNIT_NAME',
|
| 67 |
+
'LOCAL_FIRE_REPORT_ID', 'LOCAL_INCIDENT_ID',
|
| 68 |
+
'FIRE_CODE', 'FIRE_NAME',
|
| 69 |
+
'ICS_209_INCIDENT_NUMBER', 'ICS_209_NAME',
|
| 70 |
+
'MTBS_ID', 'MTBS_FIRE_NAME', 'COMPLEX_NAME',
|
| 71 |
+
'DISCOVERY_DATE', 'DISCOVERY_TIME',
|
| 72 |
+
'CONT_DATE', 'CONT_DOY', 'CONT_TIME',
|
| 73 |
+
'FIPS_CODE', 'FIPS_NAME',
|
| 74 |
+
'FIRE_SIZE', # Don't use actual size as feature - it's what we're predicting
|
| 75 |
+
'FIRE_SIZE_CLASS', # Original target
|
| 76 |
+
'Shape' # Geometry column if present
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
# Categorical features to encode
|
| 80 |
+
CATEGORICAL_FEATURES = [
|
| 81 |
+
'NWCG_REPORTING_AGENCY',
|
| 82 |
+
'STAT_CAUSE_DESCR',
|
| 83 |
+
'STATE',
|
| 84 |
+
'OWNER_DESCR'
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
# Numerical features (after feature engineering)
|
| 88 |
+
NUMERICAL_FEATURES = [
|
| 89 |
+
'LATITUDE',
|
| 90 |
+
'LONGITUDE',
|
| 91 |
+
'DISCOVERY_DOY',
|
| 92 |
+
'FIRE_YEAR'
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
# Temporal features to create
|
| 96 |
+
TEMPORAL_FEATURES = [
|
| 97 |
+
'month',
|
| 98 |
+
'season',
|
| 99 |
+
'day_of_week',
|
| 100 |
+
'is_weekend'
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
# Geospatial features to create
|
| 104 |
+
GEOSPATIAL_FEATURES = [
|
| 105 |
+
'lat_bin',
|
| 106 |
+
'lon_bin',
|
| 107 |
+
'geo_cluster',
|
| 108 |
+
'lat_lon_interaction'
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
# =============================================================================
|
| 112 |
+
# MODEL CONFIGURATION
|
| 113 |
+
# =============================================================================
|
| 114 |
+
|
| 115 |
+
# Random seed for reproducibility
|
| 116 |
+
RANDOM_STATE = 42
|
| 117 |
+
|
| 118 |
+
# Train/test split ratio
|
| 119 |
+
TEST_SIZE = 0.2
|
| 120 |
+
|
| 121 |
+
# Cross-validation folds
|
| 122 |
+
N_FOLDS = 5
|
| 123 |
+
|
| 124 |
+
# Class weights for imbalanced data (will be computed dynamically)
|
| 125 |
+
USE_CLASS_WEIGHTS = True
|
| 126 |
+
|
| 127 |
+
# LightGBM base parameters for ordinal classification
|
| 128 |
+
LIGHTGBM_PARAMS = {
|
| 129 |
+
'objective': 'multiclass',
|
| 130 |
+
'num_class': 3,
|
| 131 |
+
'metric': 'multi_logloss',
|
| 132 |
+
'boosting_type': 'gbdt',
|
| 133 |
+
'verbosity': -1,
|
| 134 |
+
'random_state': RANDOM_STATE,
|
| 135 |
+
'n_jobs': -1
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# Optuna hyperparameter search space
|
| 139 |
+
OPTUNA_SEARCH_SPACE = {
|
| 140 |
+
'n_estimators': (100, 1000),
|
| 141 |
+
'max_depth': (3, 12),
|
| 142 |
+
'learning_rate': (0.01, 0.3),
|
| 143 |
+
'num_leaves': (20, 150),
|
| 144 |
+
'min_child_samples': (10, 100),
|
| 145 |
+
'subsample': (0.6, 1.0),
|
| 146 |
+
'colsample_bytree': (0.6, 1.0),
|
| 147 |
+
'reg_alpha': (0.0, 1.0),
|
| 148 |
+
'reg_lambda': (0.0, 1.0)
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
# Number of Optuna trials
|
| 152 |
+
N_OPTUNA_TRIALS = 50
|
| 153 |
+
|
| 154 |
+
# =============================================================================
|
| 155 |
+
# GEOSPATIAL CLUSTERING CONFIGURATION
|
| 156 |
+
# =============================================================================
|
| 157 |
+
|
| 158 |
+
# Number of clusters for geographic regions
|
| 159 |
+
N_GEO_CLUSTERS = 20
|
| 160 |
+
|
| 161 |
+
# Latitude/Longitude binning
|
| 162 |
+
LAT_BINS = 10
|
| 163 |
+
LON_BINS = 10
|
| 164 |
+
|
| 165 |
+
# =============================================================================
|
| 166 |
+
# EVALUATION METRICS
|
| 167 |
+
# =============================================================================
|
| 168 |
+
|
| 169 |
+
# Primary metric for model selection
|
| 170 |
+
PRIMARY_METRIC = 'macro_f1'
|
| 171 |
+
|
| 172 |
+
# All metrics to compute
|
| 173 |
+
EVALUATION_METRICS = [
|
| 174 |
+
'accuracy',
|
| 175 |
+
'balanced_accuracy',
|
| 176 |
+
'macro_f1',
|
| 177 |
+
'weighted_f1',
|
| 178 |
+
'cohen_kappa',
|
| 179 |
+
'macro_precision',
|
| 180 |
+
'macro_recall'
|
| 181 |
+
]
|
data/processed/fires_features.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6926ded384d61d2cd531853ce79785befc12cbb9d827ebf89e5f334aefef5c57
|
| 3 |
+
size 116607705
|
data/processed/fires_processed.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:245d6fccc065f5981420244d0aa7a6d0cd12cbe9f89753852b12eba03978c75f
|
| 3 |
+
size 26619137
|
data/processed/fires_raw.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:841e692f7f4044513d96c762f726a276e165a009a7f6ecb8d05547d8cee59cf0
|
| 3 |
+
size 127864657
|
data/processed/test.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f2c42ac67caff94f6882dc3f7f00b3e5a30f37d19672d4fa5bd279795622c7df
|
| 3 |
+
size 23840600
|
data/processed/train.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb2678f4b14baf3863f32818b4030253ac731587c1f3451579dc38c1dad691d2
|
| 3 |
+
size 93657151
|
models/best_params.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"n_estimators": 516,
|
| 3 |
+
"max_depth": 9,
|
| 4 |
+
"learning_rate": 0.21074359142553023,
|
| 5 |
+
"num_leaves": 131,
|
| 6 |
+
"min_child_samples": 70,
|
| 7 |
+
"subsample": 0.7796691167092181,
|
| 8 |
+
"colsample_bytree": 0.7018466944576152,
|
| 9 |
+
"reg_alpha": 0.52062185793523,
|
| 10 |
+
"reg_lambda": 0.3425483694161869
|
| 11 |
+
}
|
models/model_metadata.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37b351a57637f9c738d07be623d8cfbfe9a4a996f1e958c9e63883e023527925
|
| 3 |
+
size 958
|
models/wildfire_model.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b1a077debe29e736c3382e934aa8b6e5c7cf235e8707d680bb8664ab5317d99
|
| 3 |
+
size 21722535
|
reports/figures/cause_by_size.png
ADDED
|
reports/figures/class_distribution.png
ADDED
|
reports/figures/classification_metrics.png
ADDED
|
reports/figures/confusion_matrix.png
ADDED
|
reports/figures/feature_importance.csv
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
feature,importance
|
| 2 |
+
day_of_week,0.03474360040175998
|
| 3 |
+
OWNER_DESCR_encoded,0.035002645531480546
|
| 4 |
+
STATE_encoded,0.04257236407411825
|
| 5 |
+
is_weekend,0.042745448253639684
|
| 6 |
+
lat_squared,0.0492982200647884
|
| 7 |
+
month_cos,0.04941415345320393
|
| 8 |
+
dist_from_center,0.04974556334097185
|
| 9 |
+
years_since_1992,0.05059611732831027
|
| 10 |
+
lat_bin,0.05266350693152481
|
| 11 |
+
month_sin,0.05535090135037826
|
| 12 |
+
NWCG_REPORTING_AGENCY_encoded,0.06568755534754865
|
| 13 |
+
dow_sin,0.06723634255462703
|
| 14 |
+
doy_sin,0.06948528640895639
|
| 15 |
+
year_normalized,0.07218163313544536
|
| 16 |
+
month,0.07541548211445999
|
| 17 |
+
lat_lon_interaction,0.08870144324496855
|
| 18 |
+
doy_cos,0.09805861682638295
|
| 19 |
+
lon_bin,0.0997225819136466
|
| 20 |
+
LONGITUDE,0.10201343368889876
|
| 21 |
+
season,0.10830839963123856
|
| 22 |
+
FIRE_YEAR,0.10860962139407908
|
| 23 |
+
is_fire_season,0.12894166850207692
|
| 24 |
+
lon_squared,0.13777922755612584
|
| 25 |
+
DISCOVERY_DOY,0.14014555480776913
|
| 26 |
+
geo_cluster,0.1491778952047764
|
| 27 |
+
STAT_CAUSE_DESCR_encoded,0.15889478851167435
|
| 28 |
+
LATITUDE,0.17045380361507326
|
| 29 |
+
dow_cos,0.23151901964397537
|
reports/figures/geographic_distribution.png
ADDED
|
Git LFS Details
|
reports/figures/missing_values.png
ADDED
|
reports/figures/prediction_distribution.png
ADDED
|
reports/figures/shap_importance.png
ADDED
|
reports/figures/shap_importance_summary.png
ADDED
|
Git LFS Details
|
reports/figures/temporal_patterns.png
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core data processing
|
| 2 |
+
pandas>=2.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
pyarrow>=14.0.0
|
| 5 |
+
|
| 6 |
+
# Machine learning
|
| 7 |
+
scikit-learn>=1.3.0
|
| 8 |
+
lightgbm>=4.0.0
|
| 9 |
+
xgboost>=2.0.0
|
| 10 |
+
imbalanced-learn>=0.11.0
|
| 11 |
+
|
| 12 |
+
# Hyperparameter tuning
|
| 13 |
+
optuna>=3.4.0
|
| 14 |
+
|
| 15 |
+
# Model interpretability
|
| 16 |
+
shap>=0.43.0
|
| 17 |
+
|
| 18 |
+
# Visualization
|
| 19 |
+
matplotlib>=3.7.0
|
| 20 |
+
seaborn>=0.12.0
|
| 21 |
+
|
| 22 |
+
# Geospatial (optional clustering)
|
| 23 |
+
scikit-learn # KMeans for coordinate clustering
|
| 24 |
+
|
| 25 |
+
# Progress bars
|
| 26 |
+
tqdm>=4.66.0
|
| 27 |
+
|
| 28 |
+
# Joblib for model persistence
|
| 29 |
+
joblib>=1.3.0
|
run_pipeline.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main Pipeline Runner
|
| 3 |
+
|
| 4 |
+
Runs the complete ML pipeline from data extraction to evaluation.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python run_pipeline.py # Run full pipeline
|
| 8 |
+
python run_pipeline.py --skip-eda # Skip EDA step
|
| 9 |
+
python run_pipeline.py --tune # Include hyperparameter tuning
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import subprocess
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def run_script(script_name: str, extra_args: list = None) -> bool:
|
| 19 |
+
"""Run a Python script and return success status."""
|
| 20 |
+
script_path = Path(__file__).parent / "scripts" / script_name
|
| 21 |
+
|
| 22 |
+
if not script_path.exists():
|
| 23 |
+
print(f"ERROR: Script not found: {script_path}")
|
| 24 |
+
return False
|
| 25 |
+
|
| 26 |
+
cmd = [sys.executable, str(script_path)]
|
| 27 |
+
if extra_args:
|
| 28 |
+
cmd.extend(extra_args)
|
| 29 |
+
|
| 30 |
+
print(f"\n{'='*60}")
|
| 31 |
+
print(f"Running: {script_name}")
|
| 32 |
+
print(f"{'='*60}\n")
|
| 33 |
+
|
| 34 |
+
result = subprocess.run(cmd, cwd=str(Path(__file__).parent))
|
| 35 |
+
|
| 36 |
+
if result.returncode != 0:
|
| 37 |
+
print(f"\nERROR: {script_name} failed with return code {result.returncode}")
|
| 38 |
+
return False
|
| 39 |
+
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
parser = argparse.ArgumentParser(description="Run wildfire ML pipeline")
|
| 45 |
+
parser.add_argument("--skip-eda", action="store_true", help="Skip EDA step")
|
| 46 |
+
parser.add_argument("--tune", action="store_true", help="Run hyperparameter tuning")
|
| 47 |
+
parser.add_argument("--from-step", type=int, default=1, help="Start from step number (1-7)")
|
| 48 |
+
args = parser.parse_args()
|
| 49 |
+
|
| 50 |
+
print("\n" + "="*60)
|
| 51 |
+
print("WILDFIRE SIZE CLASSIFICATION PIPELINE")
|
| 52 |
+
print("="*60)
|
| 53 |
+
|
| 54 |
+
steps = [
|
| 55 |
+
("01_extract_data.py", []),
|
| 56 |
+
("02_eda.py", []),
|
| 57 |
+
("03_preprocess.py", []),
|
| 58 |
+
("04_feature_engineering.py", []),
|
| 59 |
+
("05_train_model.py", ["--tune"] if args.tune else []),
|
| 60 |
+
("06_evaluate.py", []),
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
for i, (script, extra_args) in enumerate(steps, 1):
|
| 64 |
+
if i < args.from_step:
|
| 65 |
+
print(f"\nSkipping step {i}: {script}")
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
if args.skip_eda and "eda" in script:
|
| 69 |
+
print(f"\nSkipping EDA step: {script}")
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
success = run_script(script, extra_args)
|
| 73 |
+
if not success:
|
| 74 |
+
print(f"\nPipeline failed at step {i}: {script}")
|
| 75 |
+
sys.exit(1)
|
| 76 |
+
|
| 77 |
+
print("\n" + "="*60)
|
| 78 |
+
print("✓ PIPELINE COMPLETE!")
|
| 79 |
+
print("="*60)
|
| 80 |
+
print("\nOutputs:")
|
| 81 |
+
print(" - Model: models/wildfire_model.txt")
|
| 82 |
+
print(" - Figures: reports/figures/")
|
| 83 |
+
print(" - Data: data/processed/")
|
| 84 |
+
print("\nNext steps:")
|
| 85 |
+
print(" - Review figures in reports/figures/")
|
| 86 |
+
print(" - Make predictions: python scripts/07_predict.py --lat 34.05 --lon -118.24")
|
| 87 |
+
print("="*60 + "\n")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
main()
|
scripts/01_extract_data.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 01: Extract Data from SQLite Database
|
| 3 |
+
|
| 4 |
+
This script connects to the FPA FOD SQLite database, extracts the Fires table,
|
| 5 |
+
and saves it as a Parquet file for faster subsequent processing.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/01_extract_data.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sqlite3
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
# Add project root to path
|
| 18 |
+
project_root = Path(__file__).parent.parent
|
| 19 |
+
sys.path.insert(0, str(project_root))
|
| 20 |
+
|
| 21 |
+
from config.config import (
|
| 22 |
+
SQLITE_DB_PATH,
|
| 23 |
+
PROCESSED_DATA_DIR,
|
| 24 |
+
RAW_PARQUET
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def connect_to_database(db_path: Path) -> sqlite3.Connection:
|
| 29 |
+
"""Connect to SQLite database."""
|
| 30 |
+
if not db_path.exists():
|
| 31 |
+
raise FileNotFoundError(
|
| 32 |
+
f"Database not found at {db_path}. "
|
| 33 |
+
"Please ensure the SQLite file is in the project root."
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
print(f"Connecting to database: {db_path}")
|
| 37 |
+
return sqlite3.connect(db_path)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_table_info(conn: sqlite3.Connection) -> None:
|
| 41 |
+
"""Print information about tables in the database."""
|
| 42 |
+
cursor = conn.cursor()
|
| 43 |
+
|
| 44 |
+
# Get list of user tables (skip SpatiaLite system tables)
|
| 45 |
+
cursor.execute("""
|
| 46 |
+
SELECT name FROM sqlite_master
|
| 47 |
+
WHERE type='table' AND name NOT LIKE 'sqlite_%'
|
| 48 |
+
AND name NOT LIKE 'spatial%'
|
| 49 |
+
AND name NOT LIKE 'virt%'
|
| 50 |
+
AND name NOT LIKE 'view%'
|
| 51 |
+
AND name NOT LIKE 'geometry%'
|
| 52 |
+
""")
|
| 53 |
+
tables = cursor.fetchall()
|
| 54 |
+
|
| 55 |
+
print("\n" + "="*60)
|
| 56 |
+
print("DATABASE TABLES")
|
| 57 |
+
print("="*60)
|
| 58 |
+
|
| 59 |
+
for table in tables:
|
| 60 |
+
table_name = table[0]
|
| 61 |
+
try:
|
| 62 |
+
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
|
| 63 |
+
count = cursor.fetchone()[0]
|
| 64 |
+
print(f" {table_name}: {count:,} rows")
|
| 65 |
+
except Exception as e:
|
| 66 |
+
print(f" {table_name}: (could not read - {str(e)[:30]})")
|
| 67 |
+
|
| 68 |
+
print("="*60 + "\n")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def extract_fires_table(conn: sqlite3.Connection) -> pd.DataFrame:
    """Read the entire ``Fires`` table into a pandas DataFrame."""
    print("Extracting Fires table...")

    fires = pd.read_sql_query("SELECT * FROM Fires", conn)

    print(f" Loaded {len(fires):,} records")
    print(f" Columns: {len(fires.columns)}")
    print(f" Memory usage: {fires.memory_usage(deep=True).sum() / 1e6:.1f} MB")

    return fires
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def save_to_parquet(df: pd.DataFrame, output_path: Path) -> None:
    """Write *df* to *output_path* as snappy-compressed Parquet.

    Missing parent directories are created on demand.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving to Parquet: {output_path}")
    df.to_parquet(output_path, index=False, compression='snappy')

    print(f" File size: {output_path.stat().st_size / 1e6:.1f} MB")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def print_data_summary(df: pd.DataFrame) -> None:
    """Print summary statistics of the extracted data.

    Covers the year range, fire-size-class distribution, top states,
    top causes, and the ten columns with the most missing values.
    """
    n_rows = len(df)
    sep = "=" * 60

    print("\n" + sep)
    print("DATA SUMMARY")
    print(sep)

    print(f"\nDate Range: {df['FIRE_YEAR'].min()} - {df['FIRE_YEAR'].max()}")

    print("\nFire Size Class Distribution:")
    for cls, count in df['FIRE_SIZE_CLASS'].value_counts().sort_index().items():
        print(f" Class {cls}: {count:>10,} ({count / n_rows * 100:>5.1f}%)")

    print("\nTop 10 States by Fire Count:")
    for state, count in df['STATE'].value_counts().head(10).items():
        print(f" {state}: {count:,}")

    print("\nTop Causes:")
    for cause, count in df['STAT_CAUSE_DESCR'].value_counts().head(5).items():
        print(f" {cause}: {count:,} ({count / n_rows * 100:.1f}%)")

    print("\nMissing Values (top 10 columns):")
    null_counts = df.isnull().sum().sort_values(ascending=False).head(10)
    for col, count in null_counts.items():
        if count > 0:
            print(f" {col}: {count:,} ({count / n_rows * 100:.1f}%)")

    print(sep + "\n")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main():
    """Run the full extraction pipeline: connect, inspect, extract, save."""
    banner = "=" * 60
    print("\n" + banner)
    print("WILDFIRE DATA EXTRACTION")
    print(banner + "\n")

    conn = connect_to_database(SQLITE_DB_PATH)
    try:
        # Show what the database contains before pulling anything.
        get_table_info(conn)
        df = extract_fires_table(conn)
        print_data_summary(df)
        save_to_parquet(df, RAW_PARQUET)

        print("\n✓ Data extraction complete!")
        print(f" Output: {RAW_PARQUET}")
    finally:
        # Always release the SQLite handle, even if extraction failed.
        conn.close()
        print(" Database connection closed.")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Run the extraction pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
scripts/02_eda.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 02: Exploratory Data Analysis (EDA)
|
| 3 |
+
|
| 4 |
+
This script performs comprehensive EDA on the wildfire dataset:
|
| 5 |
+
- Class distribution analysis (original 7 classes and grouped 3 classes)
|
| 6 |
+
- Geographic distribution of fires
|
| 7 |
+
- Temporal patterns (yearly, monthly, seasonal)
|
| 8 |
+
- Missing value analysis
|
| 9 |
+
- Feature correlations
|
| 10 |
+
|
| 11 |
+
Generates visualization plots saved to reports/figures/
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python scripts/02_eda.py
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import seaborn as sns
|
| 24 |
+
|
| 25 |
+
# Add project root to path so `from config.config import ...` resolves when
# this script is executed directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
|
| 28 |
+
|
| 29 |
+
from config.config import (
|
| 30 |
+
RAW_PARQUET,
|
| 31 |
+
FIGURES_DIR,
|
| 32 |
+
FIRE_SIZE_CLASS_MAPPING,
|
| 33 |
+
TARGET_CLASS_NAMES
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Set a consistent global look for every figure produced by this script.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_data() -> pd.DataFrame:
    """Read the raw fires dataset from its Parquet file."""
    print("Loading data...")
    fires = pd.read_parquet(RAW_PARQUET)
    print(f" Loaded {len(fires):,} records")
    return fires
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def analyze_class_distribution(df: pd.DataFrame) -> None:
    """Print and plot the fire-size-class distribution.

    Shows both the original A-G classes and the 3-way grouping used as the
    model target, then writes a two-panel bar chart to
    ``FIGURES_DIR / 'class_distribution.png'``.

    NOTE: mutates *df* in place by adding a 'fire_size_grouped' column.
    """
    print("\n" + "="*60)
    print("CLASS DISTRIBUTION ANALYSIS")
    print("="*60)

    # Original 7 classes (A-G, ordered by sort_index)
    print("\nOriginal Fire Size Classes:")
    original_dist = df['FIRE_SIZE_CLASS'].value_counts().sort_index()
    for cls, count in original_dist.items():
        pct = count / len(df) * 100
        print(f" Class {cls}: {count:>10,} ({pct:>6.2f}%)")

    # Grouped 3 classes — this is the model's target variable.
    df['fire_size_grouped'] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)

    print("\nGrouped Classes (Target Variable):")
    grouped_dist = df['fire_size_grouped'].value_counts().sort_index()
    for cls_idx, count in grouped_dist.items():
        pct = count / len(df) * 100
        cls_name = TARGET_CLASS_NAMES[cls_idx]
        print(f" {cls_idx} ({cls_name:>6}): {count:>10,} ({pct:>6.2f}%)")

    # Visualize: left panel = original classes, right panel = grouped target.
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Original distribution
    colors_orig = sns.color_palette("YlOrRd", 7)
    ax1 = axes[0]
    original_dist.plot(kind='bar', ax=ax1, color=colors_orig, edgecolor='black')
    ax1.set_title('Original Fire Size Class Distribution', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Fire Size Class')
    ax1.set_ylabel('Count')
    ax1.tick_params(axis='x', rotation=0)

    # Add percentage labels above each bar
    for i, (idx, val) in enumerate(original_dist.items()):
        pct = val / len(df) * 100
        ax1.annotate(f'{pct:.1f}%', (i, val), ha='center', va='bottom', fontsize=9)

    # Grouped distribution
    colors_grouped = ['#2ecc71', '#f39c12', '#e74c3c']  # Green, Orange, Red
    ax2 = axes[1]
    grouped_dist.plot(kind='bar', ax=ax2, color=colors_grouped, edgecolor='black')
    ax2.set_title('Grouped Fire Size Distribution (Target)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Fire Size Category')
    ax2.set_ylabel('Count')
    ax2.set_xticklabels(TARGET_CLASS_NAMES, rotation=0)

    # Add percentage labels above each bar
    for i, (idx, val) in enumerate(grouped_dist.items()):
        pct = val / len(df) * 100
        ax2.annotate(f'{pct:.1f}%', (i, val), ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'class_distribution.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"\n Saved: class_distribution.png")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def analyze_geographic_distribution(df: pd.DataFrame) -> None:
    """Print top states by fire count and plot fire locations.

    Writes a two-panel scatter map (all fires sampled, large fires only) to
    ``FIGURES_DIR / 'geographic_distribution.png'``. Axis limits cover the
    contiguous US; points outside lon [-130, -65] / lat [24, 50]
    (e.g. Alaska/Hawaii) fall outside the view.
    """
    print("\n" + "="*60)
    print("GEOGRAPHIC DISTRIBUTION")
    print("="*60)

    # Top states
    print("\nTop 15 States by Fire Count:")
    state_dist = df['STATE'].value_counts().head(15)
    for state, count in state_dist.items():
        pct = count / len(df) * 100
        print(f" {state}: {count:>10,} ({pct:>5.1f}%)")

    # Fire locations scatter plot
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # All fires — sampled (fixed seed) to keep the scatter plot tractable.
    sample_size = min(100000, len(df))
    df_sample = df.sample(n=sample_size, random_state=42)

    ax1 = axes[0]
    scatter = ax1.scatter(
        df_sample['LONGITUDE'],
        df_sample['LATITUDE'],
        # Map class letters to ints so the colormap can encode severity.
        c=df_sample['FIRE_SIZE_CLASS'].map({'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6}),
        cmap='YlOrRd',
        alpha=0.3,
        s=1
    )
    ax1.set_title(f'Fire Locations (n={sample_size:,} sample)', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Longitude')
    ax1.set_ylabel('Latitude')
    ax1.set_xlim(-130, -65)
    ax1.set_ylim(24, 50)
    plt.colorbar(scatter, ax=ax1, label='Fire Size Class (A=0 to G=6)')

    # Large fires only (classes E, F, G) — no sampling, they are rare.
    df_large = df[df['FIRE_SIZE_CLASS'].isin(['E', 'F', 'G'])]

    ax2 = axes[1]
    scatter2 = ax2.scatter(
        df_large['LONGITUDE'],
        df_large['LATITUDE'],
        c=df_large['FIRE_SIZE_CLASS'].map({'E': 0, 'F': 1, 'G': 2}),
        cmap='Reds',
        alpha=0.5,
        s=5
    )
    ax2.set_title(f'Large Fires Only (E/F/G, n={len(df_large):,})', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Longitude')
    ax2.set_ylabel('Latitude')
    ax2.set_xlim(-130, -65)
    ax2.set_ylim(24, 50)

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'geographic_distribution.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"\n Saved: geographic_distribution.png")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def analyze_temporal_patterns(df: pd.DataFrame) -> None:
    """Analyze temporal patterns: yearly trend, monthly counts, size-by-month,
    and top causes.

    Writes a 2x2 panel to ``FIGURES_DIR / 'temporal_patterns.png'`` and
    prints a per-month breakdown.

    NOTE: mutates *df* in place by adding 'month' and 'fire_size_grouped'
    columns.
    """
    print("\n" + "="*60)
    print("TEMPORAL PATTERNS")
    print("="*60)

    # Convert discovery day-of-year to month.
    # BUG FIX: pd.to_datetime(..., format='%j') anchors dates to the default
    # year 1900, which is not a leap year, so DISCOVERY_DOY == 366 (fires
    # discovered on Dec 31 of a leap year) raises ValueError. Anchoring to a
    # leap year via a day offset handles 1-366 uniformly.
    df['month'] = (
        pd.Timestamp('2000-01-01')
        + pd.to_timedelta(df['DISCOVERY_DOY'] - 1, unit='D')
    ).dt.month

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Yearly trend
    ax1 = axes[0, 0]
    yearly = df.groupby('FIRE_YEAR').size()
    yearly.plot(kind='line', ax=ax1, marker='o', linewidth=2, markersize=4)
    ax1.set_title('Fires per Year', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Year')
    ax1.set_ylabel('Number of Fires')
    ax1.grid(True, alpha=0.3)

    # Monthly distribution
    ax2 = axes[0, 1]
    monthly = df.groupby('month').size()
    monthly.plot(kind='bar', ax=ax2, color='coral', edgecolor='black')
    ax2.set_title('Fires by Month', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Month')
    ax2.set_ylabel('Number of Fires')
    ax2.tick_params(axis='x', rotation=0)

    # Fire size category by month (grouped bars per target class)
    ax3 = axes[1, 0]
    df['fire_size_grouped'] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)
    monthly_by_class = df.groupby(['month', 'fire_size_grouped']).size().unstack(fill_value=0)
    monthly_by_class.columns = TARGET_CLASS_NAMES
    monthly_by_class.plot(kind='bar', ax=ax3, width=0.8,
                          color=['#2ecc71', '#f39c12', '#e74c3c'], edgecolor='black')
    ax3.set_title('Fire Size Category by Month', fontsize=12, fontweight='bold')
    ax3.set_xlabel('Month')
    ax3.set_ylabel('Number of Fires')
    ax3.tick_params(axis='x', rotation=0)
    ax3.legend(title='Size Category')

    # Fire causes
    ax4 = axes[1, 1]
    cause_dist = df['STAT_CAUSE_DESCR'].value_counts().head(10)
    cause_dist.plot(kind='barh', ax=ax4, color='steelblue', edgecolor='black')
    ax4.set_title('Top 10 Fire Causes', fontsize=12, fontweight='bold')
    ax4.set_xlabel('Number of Fires')
    ax4.invert_yaxis()

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'temporal_patterns.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"\n Saved: temporal_patterns.png")

    # Print monthly stats
    print("\nFires by Month:")
    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    for month, count in monthly.items():
        pct = count / len(df) * 100
        print(f" {month_names[month-1]}: {count:>10,} ({pct:>5.1f}%)")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def analyze_missing_values(df: pd.DataFrame) -> None:
    """Report per-column missing-value counts and plot the worst offenders.

    Only columns with at least one null are listed; a horizontal bar chart
    of the 20 worst columns is saved as missing_values.png (skipped when
    nothing is missing).
    """
    print("\n" + "="*60)
    print("MISSING VALUE ANALYSIS")
    print("="*60)

    null_counts = df.isnull().sum()
    summary = pd.DataFrame({
        'Missing Count': null_counts,
        'Missing %': (null_counts / len(df) * 100).round(2)
    }).sort_values('Missing Count', ascending=False)

    # Keep only columns that actually have gaps.
    summary = summary[summary['Missing Count'] > 0]
    worst = summary.head(20)

    print(f"\nColumns with missing values: {len(summary)}")
    print("\nTop 20 columns with missing values:")
    for col, row in worst.iterrows():
        print(f" {col}: {row['Missing Count']:,} ({row['Missing %']:.1f}%)")

    if summary.empty:
        return

    fig, ax = plt.subplots(figsize=(12, 8))
    worst['Missing %'].plot(
        kind='barh', ax=ax, color='salmon', edgecolor='black'
    )
    ax.set_title('Missing Values by Column (Top 20)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Missing %')
    ax.invert_yaxis()

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'missing_values.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"\n Saved: missing_values.png")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def analyze_cause_by_size(df: pd.DataFrame) -> None:
    """Cross-tabulate fire cause against grouped size category.

    Prints a row-normalized percentage table and writes a stacked bar
    chart to ``FIGURES_DIR / 'cause_by_size.png'``.

    NOTE: mutates *df* in place by adding a 'fire_size_grouped' column.
    """
    print("\n" + "="*60)
    print("FIRE CAUSE BY SIZE ANALYSIS")
    print("="*60)

    df['fire_size_grouped'] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)

    # Cross-tabulation; normalize='index' makes each cause's row sum to 100%.
    cause_size = pd.crosstab(
        df['STAT_CAUSE_DESCR'],
        df['fire_size_grouped'],
        normalize='index'
    ) * 100
    cause_size.columns = TARGET_CLASS_NAMES

    print("\nFire Cause Distribution by Size Category (% of each cause):")
    print(cause_size.round(1).to_string())

    # Visualize as stacked horizontal bars, one per cause.
    fig, ax = plt.subplots(figsize=(12, 8))
    cause_size.plot(kind='barh', ax=ax, stacked=True,
                    color=['#2ecc71', '#f39c12', '#e74c3c'], edgecolor='white')
    ax.set_title('Fire Size Distribution by Cause', fontsize=14, fontweight='bold')
    ax.set_xlabel('Percentage')
    ax.legend(title='Size Category', loc='lower right')
    ax.invert_yaxis()

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'cause_by_size.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"\n Saved: cause_by_size.png")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def analyze_owner_distribution(df: pd.DataFrame) -> None:
    """Print the ten most common land-owner categories with percentages."""
    print("\n" + "="*60)
    print("LAND OWNER ANALYSIS")
    print("="*60)

    total = len(df)
    top_owners = df['OWNER_DESCR'].value_counts().head(10)
    print("\nFires by Land Owner:")
    for owner, count in top_owners.items():
        print(f" {owner}: {count:,} ({count / total * 100:.1f}%)")
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def main():
    """Run every EDA step and save all figures."""
    print("\n" + "="*60)
    print("EXPLORATORY DATA ANALYSIS")
    print("="*60)

    # Ensure the output directory exists before any savefig call.
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

    df = load_data()

    # Run each analysis in order (some add helper columns to df in place).
    analyses = (
        analyze_class_distribution,
        analyze_geographic_distribution,
        analyze_temporal_patterns,
        analyze_missing_values,
        analyze_cause_by_size,
        analyze_owner_distribution,
    )
    for analysis in analyses:
        analysis(df)

    print("\n" + "="*60)
    print("✓ EDA Complete!")
    print(f" Figures saved to: {FIGURES_DIR}")
    print("="*60 + "\n")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# Run the EDA pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
scripts/03_preprocess.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 03: Data Preprocessing
|
| 3 |
+
|
| 4 |
+
This script preprocesses the raw wildfire data:
|
| 5 |
+
- Creates ordinal target variable (3 classes: Small, Medium, Large)
|
| 6 |
+
- Drops irrelevant columns (IDs, text fields, redundant info)
|
| 7 |
+
- Handles missing values
|
| 8 |
+
- Encodes categorical variables
|
| 9 |
+
- Splits data into train/test sets (stratified)
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python scripts/03_preprocess.py
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pandas as pd
|
| 20 |
+
from sklearn.model_selection import train_test_split
|
| 21 |
+
from sklearn.preprocessing import LabelEncoder
|
| 22 |
+
|
| 23 |
+
# Add project root to path so `from config.config import ...` resolves when
# this script is executed directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
|
| 26 |
+
|
| 27 |
+
from config.config import (
|
| 28 |
+
RAW_PARQUET,
|
| 29 |
+
PROCESSED_PARQUET,
|
| 30 |
+
TRAIN_PARQUET,
|
| 31 |
+
TEST_PARQUET,
|
| 32 |
+
PROCESSED_DATA_DIR,
|
| 33 |
+
FIRE_SIZE_CLASS_MAPPING,
|
| 34 |
+
TARGET_CLASS_NAMES,
|
| 35 |
+
TARGET_COLUMN,
|
| 36 |
+
COLUMNS_TO_DROP,
|
| 37 |
+
CATEGORICAL_FEATURES,
|
| 38 |
+
RANDOM_STATE,
|
| 39 |
+
TEST_SIZE
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_data() -> pd.DataFrame:
    """Read the raw fires dataset from its Parquet file."""
    print("Loading raw data...")
    raw = pd.read_parquet(RAW_PARQUET)
    print(f" Loaded {len(raw):,} records with {len(raw.columns)} columns")
    return raw
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def create_target_variable(df: pd.DataFrame) -> pd.DataFrame:
    """Derive the ordinal target (0/1/2) from FIRE_SIZE_CLASS.

    Rows whose size class is absent from FIRE_SIZE_CLASS_MAPPING are
    dropped. Returns the (possibly shortened) DataFrame with TARGET_COLUMN
    added as an int column.
    """
    print("\nCreating ordinal target variable...")

    df[TARGET_COLUMN] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)

    # Any class letter missing from the mapping produces NaN here.
    unmapped = df[TARGET_COLUMN].isna().sum()
    if unmapped > 0:
        print(f" Warning: {unmapped} records could not be mapped. Dropping...")
        df = df.dropna(subset=[TARGET_COLUMN])

    df[TARGET_COLUMN] = df[TARGET_COLUMN].astype(int)

    print("\n Target Variable Distribution:")
    total = len(df)
    for label in sorted(df[TARGET_COLUMN].unique()):
        n = (df[TARGET_COLUMN] == label).sum()
        print(f" {label} ({TARGET_CLASS_NAMES[label]}): {n:,} ({n / total * 100:.2f}%)")

    return df
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def drop_irrelevant_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Remove the configured non-predictive columns that are present."""
    print("\nDropping irrelevant columns...")

    # Only drop columns that actually exist in this frame.
    present = [col for col in COLUMNS_TO_DROP if col in df.columns]

    print(f" Dropping {len(present)} columns:")
    for col in present[:10]:
        print(f" - {col}")
    if len(present) > 10:
        print(f" ... and {len(present) - 10} more")

    trimmed = df.drop(columns=present, errors='ignore')
    print(f" Remaining columns: {len(trimmed.columns)}")

    return trimmed
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def handle_missing_values(df: pd.DataFrame) -> pd.DataFrame:
    """Drop rows missing essential fields; fill categorical NaNs with 'Unknown'.

    Essential fields are location, year, discovery day-of-year, and the
    target column (when present).
    """
    print("\nHandling missing values...")

    initial_rows = len(df)
    essential_cols = ['LATITUDE', 'LONGITUDE', 'FIRE_YEAR', 'DISCOVERY_DOY', TARGET_COLUMN]
    present_essentials = [c for c in essential_cols if c in df.columns]

    # Report null counts in the must-have columns before dropping.
    for col in present_essentials:
        n_null = df[col].isna().sum()
        if n_null > 0:
            print(f" {col}: {n_null} missing values")

    df = df.dropna(subset=present_essentials)

    # Categorical gaps become an explicit 'Unknown' category.
    for col in CATEGORICAL_FEATURES:
        if col in df.columns:
            n_null = df[col].isna().sum()
            if n_null > 0:
                df[col] = df[col].fillna('Unknown')
                print(f" {col}: Filled {n_null} missing with 'Unknown'")

    print(f"\n Rows dropped due to missing essential values: {initial_rows - len(df):,}")
    print(f" Remaining rows: {len(df):,}")

    return df
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def encode_categorical_features(df: pd.DataFrame) -> tuple[pd.DataFrame, dict]:
    """Label-encode each configured categorical column into '<col>_encoded'.

    Returns the DataFrame plus a dict of fitted LabelEncoders keyed by the
    original column name (kept so new data can be transformed/decoded later).
    """
    print("\nEncoding categorical features...")

    encoders: dict = {}
    for col in (c for c in CATEGORICAL_FEATURES if c in df.columns):
        encoder = LabelEncoder()
        # astype(str) makes any residual NaN/None a plain string category.
        df[f'{col}_encoded'] = encoder.fit_transform(df[col].astype(str))
        encoders[col] = encoder
        print(f" {col}: {len(encoder.classes_)} categories")

    return df, encoders
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def select_features(df: pd.DataFrame) -> pd.DataFrame:
    """Reduce the frame to modeling features, the target, and the raw
    categorical columns (kept for reference)."""
    print("\nSelecting features for modeling...")

    feature_cols = [
        # Numerical
        'LATITUDE', 'LONGITUDE', 'FIRE_YEAR', 'DISCOVERY_DOY',
        # Encoded categorical
        'NWCG_REPORTING_AGENCY_encoded',
        'STAT_CAUSE_DESCR_encoded',
        'STATE_encoded',
        'OWNER_DESCR_encoded',
        # Target
        TARGET_COLUMN
    ]

    available_cols = [col for col in feature_cols if col in df.columns]
    original_cats = [col for col in CATEGORICAL_FEATURES if col in df.columns]

    # dict.fromkeys de-duplicates while preserving first-seen order.
    keep = list(dict.fromkeys(available_cols + original_cats))
    df = df[keep]

    print(f" Selected {len(available_cols)} feature columns + target")
    print(f" Final columns: {list(df.columns)}")

    return df
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def split_data(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Stratified train/test split on the target column.

    Returns (train_df, test_df) using the configured TEST_SIZE and
    RANDOM_STATE so the split is reproducible.
    """
    print("\nSplitting data into train/test sets...")

    train_df, test_df = train_test_split(
        df,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=df[TARGET_COLUMN]
    )

    print(f" Train set: {len(train_df):,} rows ({100*(1-TEST_SIZE):.0f}%)")
    print(f" Test set: {len(test_df):,} rows ({100*TEST_SIZE:.0f}%)")

    # Confirm class proportions survived the stratified split.
    print("\n Target distribution in splits:")
    for label, subset in (('Train', train_df), ('Test', test_df)):
        shares = subset[TARGET_COLUMN].value_counts(normalize=True).sort_index() * 100
        pretty = ", ".join(f"{TARGET_CLASS_NAMES[i]}: {v:.1f}%" for i, v in shares.items())
        print(f" {label}: {pretty}")

    return train_df, test_df
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def save_data(df: pd.DataFrame, train_df: pd.DataFrame, test_df: pd.DataFrame) -> None:
    """Persist the full processed frame plus the train/test splits as Parquet."""
    print("\nSaving processed data...")

    # Create directory if needed
    PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)

    outputs = (
        (df, PROCESSED_PARQUET, " Full processed data"),
        (train_df, TRAIN_PARQUET, " Train data"),
        (test_df, TEST_PARQUET, " Test data"),
    )
    for frame, path, label in outputs:
        frame.to_parquet(path, index=False)
        print(f"{label}: {path}")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def print_summary(df: pd.DataFrame) -> None:
    """Print dataset shape, dtype counts, and per-feature numeric ranges."""
    print("\n" + "="*60)
    print("PREPROCESSING SUMMARY")
    print("="*60)

    print(f"\nDataset shape: {df.shape}")
    print(f"\nColumn types:")
    print(df.dtypes.value_counts().to_string())

    print(f"\nFeature statistics:")
    # Range/mean/std for every numeric feature except the target itself.
    for col in df.select_dtypes(include=[np.number]).columns:
        if col == TARGET_COLUMN:
            continue
        print(f" {col}:")
        print(f" Range: [{df[col].min():.2f}, {df[col].max():.2f}]")
        print(f" Mean: {df[col].mean():.2f}, Std: {df[col].std():.2f}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def main():
    """Run the preprocessing pipeline end to end: load, clean, encode,
    select, split, and save."""
    banner = "=" * 60
    print("\n" + banner)
    print("DATA PREPROCESSING")
    print(banner)

    df = load_data()
    df = create_target_variable(df)
    df = drop_irrelevant_columns(df)
    df = handle_missing_values(df)
    df, encoders = encode_categorical_features(df)
    df = select_features(df)

    train_df, test_df = split_data(df)
    save_data(df, train_df, test_df)
    print_summary(df)

    print("\n" + banner)
    print("✓ Preprocessing Complete!")
    print(banner + "\n")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
scripts/04_feature_engineering.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 04: Feature Engineering
|
| 3 |
+
|
| 4 |
+
This script creates additional features for the model:
|
| 5 |
+
- Temporal features (month, season, day of week)
|
| 6 |
+
- Geospatial features (lat/lon bins, clustering, interactions)
|
| 7 |
+
- Coordinate transformations
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python scripts/04_feature_engineering.py
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from sklearn.cluster import KMeans
|
| 19 |
+
from sklearn.preprocessing import StandardScaler
|
| 20 |
+
|
| 21 |
+
# Add project root to path
|
| 22 |
+
project_root = Path(__file__).parent.parent
|
| 23 |
+
sys.path.insert(0, str(project_root))
|
| 24 |
+
|
| 25 |
+
from config.config import (
|
| 26 |
+
TRAIN_PARQUET,
|
| 27 |
+
TEST_PARQUET,
|
| 28 |
+
FEATURES_PARQUET,
|
| 29 |
+
PROCESSED_DATA_DIR,
|
| 30 |
+
TARGET_COLUMN,
|
| 31 |
+
N_GEO_CLUSTERS,
|
| 32 |
+
LAT_BINS,
|
| 33 |
+
LON_BINS,
|
| 34 |
+
RANDOM_STATE
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_data() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read the previously saved train/test splits from parquet."""
    print("Loading data...")
    train = pd.read_parquet(TRAIN_PARQUET)
    test = pd.read_parquet(TEST_PARQUET)
    print(f" Train: {len(train):,} rows")
    print(f" Test: {len(test):,} rows")
    return train, test
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def create_temporal_features(df: pd.DataFrame) -> pd.DataFrame:
    """Create temporal features from DISCOVERY_DOY and FIRE_YEAR.

    BUG FIX: the original mapped every day-of-year onto a fixed reference
    year (2001), so day_of_week/is_weekend did not reflect the fire's real
    weekday, and DOY 366 of leap years parsed to NaT. Dates are now built
    from each row's actual FIRE_YEAR.

    Adds: month, day_of_week (0=Monday), is_weekend, season
    (1=Winter..4=Fall), is_fire_season (June-October).
    Mutates df in place and returns it.
    """
    print("\nCreating temporal features...")

    # Build the real calendar date for each row from its own year.
    doy_str = df['DISCOVERY_DOY'].astype(int).astype(str)
    year_str = df['FIRE_YEAR'].astype(int).astype(str)
    df['temp_date'] = pd.to_datetime(doy_str + '-' + year_str,
                                     format='%j-%Y', errors='coerce')

    # Rows whose DOY is out of range for their year (e.g. DOY 366 in a
    # non-leap year) fall back to the dataset's median day of year.
    invalid = df['temp_date'].isna()
    invalid_dates = int(invalid.sum())
    if invalid_dates > 0:
        print(f" Warning: {invalid_dates} invalid day of year values")
        median_doy = str(int(df['DISCOVERY_DOY'].median()))
        df.loc[invalid, 'temp_date'] = pd.to_datetime(
            median_doy + '-' + year_str[invalid], format='%j-%Y', errors='coerce'
        )

    # Extract features
    df['month'] = df['temp_date'].dt.month
    df['day_of_week'] = df['temp_date'].dt.dayofweek  # 0=Monday, 6=Sunday
    df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)

    # Season (1=Winter, 2=Spring, 3=Summer, 4=Fall)
    df['season'] = df['month'].apply(lambda m:
        1 if m in [12, 1, 2] else
        2 if m in [3, 4, 5] else
        3 if m in [6, 7, 8] else 4
    )

    # Fire season indicator (peak fire months: June-October)
    df['is_fire_season'] = df['month'].isin([6, 7, 8, 9, 10]).astype(int)

    # Drop the intermediate date column
    df = df.drop(columns=['temp_date'])

    print(" Created: month, day_of_week, is_weekend, season, is_fire_season")

    return df
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def create_geospatial_features(train_df: pd.DataFrame, test_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, KMeans]:
    """Create geospatial features from coordinates.

    Adds lat/lon bins, a K-Means cluster id, quadratic/interaction terms
    and a distance-from-center feature to BOTH frames. The frames are
    mutated in place (the loop variable aliases them) and also returned,
    along with the fitted KMeans so callers can persist/reuse it.
    The scaler and K-Means are fit on train only, then applied to test,
    so no information leaks from the test split.
    """
    print("\nCreating geospatial features...")

    # 1. Latitude/Longitude bins
    print(" Creating coordinate bins...")

    # Define bin edges based on continental US bounds
    lat_min, lat_max = 24.0, 50.0
    lon_min, lon_max = -125.0, -66.0

    lat_edges = np.linspace(lat_min, lat_max, LAT_BINS + 1)
    lon_edges = np.linspace(lon_min, lon_max, LON_BINS + 1)

    for df in [train_df, test_df]:
        # labels=False yields the integer bin index rather than an Interval.
        df['lat_bin'] = pd.cut(df['LATITUDE'], bins=lat_edges, labels=False, include_lowest=True)
        df['lon_bin'] = pd.cut(df['LONGITUDE'], bins=lon_edges, labels=False, include_lowest=True)

        # Fill NaN bins (locations outside continental US) with nearest bin
        # NOTE(review): median-fill, not truly "nearest"; astype(int) would
        # raise if a whole column were NaN (all points outside CONUS) — confirm
        # the data guarantees at least one in-bounds point per split.
        df['lat_bin'] = df['lat_bin'].fillna(df['lat_bin'].median()).astype(int)
        df['lon_bin'] = df['lon_bin'].fillna(df['lon_bin'].median()).astype(int)

    # 2. Geographic clustering using K-Means
    print(f" Fitting K-Means clustering (k={N_GEO_CLUSTERS})...")

    # Prepare coordinates for clustering
    train_coords = train_df[['LATITUDE', 'LONGITUDE']].values
    test_coords = test_df[['LATITUDE', 'LONGITUDE']].values

    # Scale coordinates (fit on train, apply to test)
    scaler = StandardScaler()
    train_coords_scaled = scaler.fit_transform(train_coords)
    test_coords_scaled = scaler.transform(test_coords)

    # Fit K-Means on train data only; test gets assigned to train clusters.
    kmeans = KMeans(n_clusters=N_GEO_CLUSTERS, random_state=RANDOM_STATE, n_init=10)
    train_df['geo_cluster'] = kmeans.fit_predict(train_coords_scaled)
    test_df['geo_cluster'] = kmeans.predict(test_coords_scaled)

    print(f" Cluster distribution (train):")
    cluster_dist = train_df['geo_cluster'].value_counts().sort_index()
    for cluster, count in cluster_dist.items():
        pct = count / len(train_df) * 100
        if pct >= 3:  # Only show clusters with >= 3%
            print(f" Cluster {cluster}: {count:,} ({pct:.1f}%)")

    # 3. Coordinate interactions
    print(" Creating coordinate interactions...")

    for df in [train_df, test_df]:
        # Quadratic terms (captures non-linear patterns)
        df['lat_squared'] = df['LATITUDE'] ** 2
        df['lon_squared'] = df['LONGITUDE'] ** 2
        df['lat_lon_interaction'] = df['LATITUDE'] * df['LONGITUDE']

        # Distance from geographic center of continental US
        # Approximate center: 39.8°N, 98.6°W (Euclidean in degrees, not km)
        center_lat, center_lon = 39.8, -98.6
        df['dist_from_center'] = np.sqrt(
            (df['LATITUDE'] - center_lat) ** 2 +
            (df['LONGITUDE'] - center_lon) ** 2
        )

    print(" Created: lat_bin, lon_bin, geo_cluster, lat_squared, lon_squared, lat_lon_interaction, dist_from_center")

    return train_df, test_df, kmeans
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def create_cyclical_features(df: pd.DataFrame) -> pd.DataFrame:
    """Encode periodic features as sine/cosine pairs.

    The sin/cos pair maps each periodic value onto the unit circle so the
    model sees December and January (or Sunday and Monday) as adjacent.
    Mutates df in place and returns it.
    """
    print("\nCreating cyclical features...")

    two_pi = 2 * np.pi
    for prefix, source, period in (
        ('month', df['month'], 12),
        ('doy', df['DISCOVERY_DOY'], 365),
        ('dow', df['day_of_week'], 7),
    ):
        angle = two_pi * source / period
        df[f'{prefix}_sin'] = np.sin(angle)
        df[f'{prefix}_cos'] = np.cos(angle)

    print(" Created: month_sin/cos, doy_sin/cos, dow_sin/cos")

    return df
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def create_year_features(df: pd.DataFrame, *, min_year: int = 1992, max_year: int = 2015) -> pd.DataFrame:
    """Create year-based trend features.

    Args:
        df: frame containing a FIRE_YEAR column.
        min_year, max_year: bounds of the dataset's year range. Exposed as
            keyword-only parameters (defaults match the FPA-FOD 1992-2015
            extract) so the function generalizes to other extracts without
            changing existing call sites.

    Returns:
        df (mutated in place) with year_normalized (0-1 over the range)
        and years_since_1992 (offset from min_year; column name kept for
        backward compatibility).
    """
    print("\nCreating year features...")

    # Normalized year (0-1 scale across the configured range)
    span = max_year - min_year
    # Guard against a degenerate single-year extract (division by zero).
    df['year_normalized'] = (df['FIRE_YEAR'] - min_year) / span if span else 0.0

    # Years since start of the range
    df['years_since_1992'] = df['FIRE_YEAR'] - min_year

    print(" Created: year_normalized, years_since_1992")

    return df
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_feature_columns(df: pd.DataFrame) -> list:
    """Return the model feature columns in frame order.

    Excludes the target and the raw categorical text columns (which are
    superseded by their encoded versions); COUNTY may or may not be present.
    """
    dropped = {
        TARGET_COLUMN,
        'NWCG_REPORTING_AGENCY', 'STAT_CAUSE_DESCR', 'STATE', 'OWNER_DESCR',
        'COUNTY',
    }
    return [col for col in df.columns if col not in dropped]
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def save_data(train_df: pd.DataFrame, test_df: pd.DataFrame) -> None:
    """Persist the feature-engineered splits, plus a combined copy."""
    print("\nSaving feature-engineered data...")

    # The split files are overwritten in place with the enriched frames.
    train_df.to_parquet(TRAIN_PARQUET, index=False)
    test_df.to_parquet(TEST_PARQUET, index=False)
    print(f" Train data: {TRAIN_PARQUET}")
    print(f" Test data: {TEST_PARQUET}")

    # A single concatenated file is kept for ad-hoc reference/analysis.
    combined = pd.concat([train_df, test_df], ignore_index=True)
    combined.to_parquet(FEATURES_PARQUET, index=False)
    print(f" Combined data: {FEATURES_PARQUET}")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def print_summary(train_df: pd.DataFrame) -> None:
    """Print a grouped overview of the engineered feature set."""
    header = "=" * 60
    print("\n" + header)
    print("FEATURE ENGINEERING SUMMARY")
    print(header)

    feature_cols = get_feature_columns(train_df)
    print(f"\nTotal features: {len(feature_cols)}")
    print("\nFeature list:")

    # Known feature names per family; filtering feature_cols (rather than
    # iterating the sets) keeps the frame's column order in the output.
    temporal_names = {'month', 'day_of_week', 'is_weekend', 'season', 'is_fire_season',
                      'month_sin', 'month_cos', 'doy_sin', 'doy_cos', 'dow_sin', 'dow_cos'}
    geospatial_names = {'lat_bin', 'lon_bin', 'geo_cluster', 'lat_squared', 'lon_squared',
                        'lat_lon_interaction', 'dist_from_center', 'LATITUDE', 'LONGITUDE'}
    year_names = {'FIRE_YEAR', 'year_normalized', 'years_since_1992', 'DISCOVERY_DOY'}

    temporal = [c for c in feature_cols if c in temporal_names]
    geospatial = [c for c in feature_cols if c in geospatial_names]
    year_feats = [c for c in feature_cols if c in year_names]
    encoded = [c for c in feature_cols if c.endswith('_encoded')]

    print(f"\n Temporal ({len(temporal)}): {temporal}")
    print(f"\n Geospatial ({len(geospatial)}): {geospatial}")
    print(f"\n Year-based ({len(year_feats)}): {year_feats}")
    print(f"\n Encoded categorical ({len(encoded)}): {encoded}")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def main():
    """Run the feature-engineering pipeline on the train/test split."""
    banner = "=" * 60
    print("\n" + banner)
    print("FEATURE ENGINEERING")
    print(banner)

    train_df, test_df = load_data()

    # Temporal features are per-frame; run on train then test.
    train_df = create_temporal_features(train_df)
    test_df = create_temporal_features(test_df)

    # Clustering is fit on train only, then applied to test.
    train_df, test_df, _kmeans = create_geospatial_features(train_df, test_df)

    # Remaining per-frame transforms, applied in the original order.
    for transform in (create_cyclical_features, create_year_features):
        train_df = transform(train_df)
        test_df = transform(test_df)

    save_data(train_df, test_df)
    print_summary(train_df)

    print("\n" + banner)
    print("✓ Feature Engineering Complete!")
    print(banner + "\n")


if __name__ == "__main__":
    main()
|
scripts/05_train_model.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 05: Model Training
|
| 3 |
+
|
| 4 |
+
This script trains the ordinal classification model:
|
| 5 |
+
- Uses LightGBM for multi-class ordinal classification
|
| 6 |
+
- Implements class weighting for imbalanced data
|
| 7 |
+
- Performs cross-validation
|
| 8 |
+
- Includes hyperparameter tuning with Optuna
|
| 9 |
+
- Saves the trained model
|
| 10 |
+
|
| 11 |
+
Ordinal Classification Approach:
|
| 12 |
+
Since fire size classes have a natural order (Small < Medium < Large),
|
| 13 |
+
we use ordinal-aware training with cumulative link model concepts.
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python scripts/05_train_model.py [--tune]
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import joblib
|
| 25 |
+
import lightgbm as lgb
|
| 26 |
+
import numpy as np
|
| 27 |
+
import optuna
|
| 28 |
+
import pandas as pd
|
| 29 |
+
from sklearn.metrics import (
|
| 30 |
+
accuracy_score,
|
| 31 |
+
balanced_accuracy_score,
|
| 32 |
+
classification_report,
|
| 33 |
+
cohen_kappa_score,
|
| 34 |
+
f1_score,
|
| 35 |
+
)
|
| 36 |
+
from sklearn.model_selection import StratifiedKFold
|
| 37 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 38 |
+
|
| 39 |
+
# Add project root to path
|
| 40 |
+
project_root = Path(__file__).parent.parent
|
| 41 |
+
sys.path.insert(0, str(project_root))
|
| 42 |
+
|
| 43 |
+
from config.config import (
|
| 44 |
+
TRAIN_PARQUET,
|
| 45 |
+
TEST_PARQUET,
|
| 46 |
+
MODELS_DIR,
|
| 47 |
+
TARGET_COLUMN,
|
| 48 |
+
TARGET_CLASS_NAMES,
|
| 49 |
+
LIGHTGBM_PARAMS,
|
| 50 |
+
OPTUNA_SEARCH_SPACE,
|
| 51 |
+
N_OPTUNA_TRIALS,
|
| 52 |
+
N_FOLDS,
|
| 53 |
+
RANDOM_STATE,
|
| 54 |
+
USE_CLASS_WEIGHTS,
|
| 55 |
+
PRIMARY_METRIC
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_data() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read the feature-engineered train and test splits from parquet."""
    print("Loading data...")
    frames = [pd.read_parquet(path) for path in (TRAIN_PARQUET, TEST_PARQUET)]
    for label, frame in zip(("Train", "Test"), frames):
        print(f" {label}: {len(frame):,} rows")
    return frames[0], frames[1]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_feature_columns(df: pd.DataFrame) -> list:
    """Columns fed to the model: everything but the target and raw text columns."""
    non_features = frozenset((
        TARGET_COLUMN,
        'NWCG_REPORTING_AGENCY', 'STAT_CAUSE_DESCR', 'STATE', 'OWNER_DESCR',
        'COUNTY',
    ))
    return [c for c in df.columns if c not in non_features]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def prepare_data(train_df: pd.DataFrame, test_df: pd.DataFrame) -> tuple:
    """Split both frames into (X, y) numpy arrays plus the feature list."""
    print("\nPreparing data...")

    feature_cols = get_feature_columns(train_df)

    def xy(frame: pd.DataFrame):
        # Extract raw numpy matrices for LightGBM.
        return frame[feature_cols].values, frame[TARGET_COLUMN].values

    X_train, y_train = xy(train_df)
    X_test, y_test = xy(test_df)

    print(f" Features: {len(feature_cols)}")
    print(f" Feature columns: {feature_cols}")

    return X_train, y_train, X_test, y_test, feature_cols
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def compute_weights(y_train: np.ndarray) -> np.ndarray:
    """Compute per-sample weights to counter class imbalance.

    Uses the same 'balanced' formula as sklearn's compute_class_weight:
    weight_c = n_samples / (n_classes * count_c).

    IMPROVEMENT: the original mapped each sample to its weight with a
    Python list comprehension and dict lookups; np.unique's inverse index
    does the same mapping vectorized (and drops the sklearn dependency
    for this function).

    Returns:
        Array of shape (len(y_train),) with each sample's class weight.
    """
    print("\nComputing class weights...")

    classes, inverse, counts = np.unique(y_train, return_inverse=True,
                                         return_counts=True)
    class_weights = len(y_train) / (len(classes) * counts)

    weight_dict = dict(zip(classes, class_weights))
    print(f" Class weights: {weight_dict}")

    # inverse[i] is the class index of sample i, so fancy indexing yields
    # the per-sample weight array in one vectorized step.
    return class_weights[inverse]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def evaluate_model(y_true: np.ndarray, y_pred: np.ndarray, prefix: str = "") -> dict:
    """Score predictions, print each metric, and return the metric dict."""
    metrics = {}
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['balanced_accuracy'] = balanced_accuracy_score(y_true, y_pred)
    metrics['macro_f1'] = f1_score(y_true, y_pred, average='macro')
    metrics['weighted_f1'] = f1_score(y_true, y_pred, average='weighted')
    # Linear weights penalize misclassifications by ordinal distance.
    metrics['cohen_kappa'] = cohen_kappa_score(y_true, y_pred, weights='linear')

    heading = f"\n{prefix} Metrics:" if prefix else "\nMetrics:"
    print(heading)
    for name, value in metrics.items():
        print(f" {name}: {value:.4f}")

    return metrics
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def cross_validate(X: np.ndarray, y: np.ndarray, params: dict,
                   sample_weights: np.ndarray | None = None) -> tuple[float, float]:
    """Perform cross-validation and return mean and std of primary metric.

    Stratified K-fold LightGBM training with early stopping on each fold's
    validation split; the fold score is macro F1.
    NOTE(review): config exposes PRIMARY_METRIC but the score below is
    hard-coded to macro F1 — confirm they agree.
    """

    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)
    scores = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        X_fold_train, X_fold_val = X[train_idx], X[val_idx]
        y_fold_train, y_fold_val = y[train_idx], y[val_idx]

        # Per-sample weights must be subset with the same fold indices.
        if sample_weights is not None:
            weights_fold = sample_weights[train_idx]
        else:
            weights_fold = None

        # Create LightGBM datasets
        train_data = lgb.Dataset(X_fold_train, label=y_fold_train, weight=weights_fold)
        val_data = lgb.Dataset(X_fold_val, label=y_fold_val, reference=train_data)

        # Train model; early stopping watches the fold's validation set.
        model = lgb.train(
            params,
            train_data,
            num_boost_round=params.get('n_estimators', 500),
            valid_sets=[val_data],
            callbacks=[lgb.early_stopping(50, verbose=False)]
        )

        # Predict: multiclass predict returns (n_samples, n_classes)
        # probabilities, so argmax over axis 1 recovers the class label.
        y_pred = model.predict(X_fold_val)
        y_pred_class = np.argmax(y_pred, axis=1)

        # Score this fold
        score = f1_score(y_fold_val, y_pred_class, average='macro')
        scores.append(score)

    return np.mean(scores), np.std(scores)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def objective(trial: optuna.Trial, X: np.ndarray, y: np.ndarray,
              sample_weights: np.ndarray) -> float:
    """Optuna objective: cross-validated macro F1 for one sampled parameter set."""

    space = OPTUNA_SEARCH_SPACE
    params = dict(LIGHTGBM_PARAMS)

    # Suggestion order is kept stable so seeded TPE runs stay reproducible.
    params['n_estimators'] = trial.suggest_int('n_estimators', *space['n_estimators'])
    params['max_depth'] = trial.suggest_int('max_depth', *space['max_depth'])
    params['learning_rate'] = trial.suggest_float('learning_rate', *space['learning_rate'], log=True)
    params['num_leaves'] = trial.suggest_int('num_leaves', *space['num_leaves'])
    params['min_child_samples'] = trial.suggest_int('min_child_samples', *space['min_child_samples'])
    params['subsample'] = trial.suggest_float('subsample', *space['subsample'])
    params['colsample_bytree'] = trial.suggest_float('colsample_bytree', *space['colsample_bytree'])
    params['reg_alpha'] = trial.suggest_float('reg_alpha', *space['reg_alpha'])
    params['reg_lambda'] = trial.suggest_float('reg_lambda', *space['reg_lambda'])

    cv_mean, _cv_std = cross_validate(X, y, params, sample_weights)
    return cv_mean
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def tune_hyperparameters(X: np.ndarray, y: np.ndarray,
                         sample_weights: np.ndarray) -> dict:
    """Tune LightGBM hyperparameters with an Optuna TPE search.

    Returns LIGHTGBM_PARAMS overlaid with the best trial's params. The raw
    best-trial params are also written to models/best_params.json so later
    non-tuning runs can reuse them.
    """
    print("\n" + "="*60)
    print("HYPERPARAMETER TUNING")
    print("="*60)

    print(f"\nRunning {N_OPTUNA_TRIALS} Optuna trials...")

    # Seeded sampler keeps the search reproducible across runs.
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=RANDOM_STATE)
    )

    study.optimize(
        lambda trial: objective(trial, X, y, sample_weights),
        n_trials=N_OPTUNA_TRIALS,
        show_progress_bar=True
    )

    print(f"\nBest trial:")
    print(f" Value (macro F1): {study.best_trial.value:.4f}")
    print(f" Params: {study.best_trial.params}")

    # Merge best params with base params
    best_params = LIGHTGBM_PARAMS.copy()
    best_params.update(study.best_trial.params)

    # BUG FIX: ensure the models directory exists before writing. On a
    # fresh checkout, tuning runs before save_model() (which was the only
    # place that created MODELS_DIR), so open() below would raise
    # FileNotFoundError.
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    best_params_path = MODELS_DIR / 'best_params.json'
    with open(best_params_path, 'w') as f:
        # indent=2 keeps the artifact human-readable/diffable.
        json.dump(study.best_trial.params, f, indent=2)
    print(f" Best params saved: {best_params_path}")

    return best_params
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def train_final_model(X_train: np.ndarray, y_train: np.ndarray,
                      X_test: np.ndarray, y_test: np.ndarray,
                      params: dict, sample_weights: np.ndarray,
                      feature_names: list) -> tuple[lgb.Booster, dict]:
    """Train the final LightGBM model on the full training split.

    BUG FIX: the return annotation previously said ``lgb.Booster`` but the
    function returns a 2-tuple; corrected to ``tuple[lgb.Booster, dict]``.

    Returns:
        (model, test_metrics): the fitted booster and the metric dict
        computed on the held-out test split.
    """
    print("\n" + "="*60)
    print("TRAINING FINAL MODEL")
    print("="*60)

    # Create datasets; the test split doubles as the early-stopping set.
    # NOTE(review): early-stopping on the reported test set leaks a little
    # optimism into the test metrics — a separate validation split would
    # be cleaner; confirm this is acceptable for the project.
    train_data = lgb.Dataset(X_train, label=y_train, weight=sample_weights,
                             feature_name=feature_names)
    val_data = lgb.Dataset(X_test, label=y_test, reference=train_data,
                           feature_name=feature_names)

    # Train
    print("\nTraining...")
    model = lgb.train(
        params,
        train_data,
        num_boost_round=params.get('n_estimators', 2000),
        valid_sets=[train_data, val_data],
        valid_names=['train', 'test'],
        callbacks=[
            lgb.early_stopping(50, verbose=True),
            lgb.log_evaluation(period=50)
        ]
    )

    # Evaluate
    print("\n" + "-"*40)

    # Train predictions (overfitting check)
    y_train_pred = np.argmax(model.predict(X_train), axis=1)
    evaluate_model(y_train, y_train_pred, "Train")

    # Test predictions — these metrics are what gets persisted.
    y_test_pred = np.argmax(model.predict(X_test), axis=1)
    test_metrics = evaluate_model(y_test, y_test_pred, "Test")

    # Per-class breakdown
    print("\nClassification Report (Test):")
    print(classification_report(y_test, y_test_pred, target_names=TARGET_CLASS_NAMES))

    return model, test_metrics
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def save_model(model: lgb.Booster, params: dict, feature_names: list, metrics: dict) -> None:
    """Write the booster and a metadata bundle to MODELS_DIR."""
    print("\nSaving model...")

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # Native LightGBM text format for the booster itself.
    model_path = MODELS_DIR / 'wildfire_model.txt'
    model.save_model(str(model_path))
    print(f" Model: {model_path}")

    # Everything a consumer needs to reload and interpret the model.
    metadata_path = MODELS_DIR / 'model_metadata.joblib'
    joblib.dump(
        {
            'params': params,
            'feature_names': feature_names,
            'metrics': metrics,
            'target_classes': TARGET_CLASS_NAMES,
        },
        metadata_path,
    )
    print(f" Metadata: {metadata_path}")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def main():
    """Train, evaluate and persist the wildfire classifier."""
    parser = argparse.ArgumentParser(description='Train wildfire classification model')
    parser.add_argument('--tune', action='store_true', help='Run hyperparameter tuning')
    args = parser.parse_args()

    banner = "=" * 60
    print("\n" + banner)
    print("MODEL TRAINING")
    print(banner)

    train_df, test_df = load_data()
    X_train, y_train, X_test, y_test, feature_cols = prepare_data(train_df, test_df)

    sample_weights = compute_weights(y_train) if USE_CLASS_WEIGHTS else None

    # Parameter resolution: tune now, else reuse saved tuning output,
    # else fall back to hand-picked defaults.
    if args.tune:
        params = tune_hyperparameters(X_train, y_train, sample_weights)
    else:
        saved_path = MODELS_DIR / 'best_params.json'
        params = LIGHTGBM_PARAMS.copy()
        if saved_path.exists():
            with open(saved_path, 'r') as f:
                params.update(json.load(f))
            print(f"Loaded best params from {saved_path}")
        else:
            params.update({
                'n_estimators': 500,
                'max_depth': 8,
                'learning_rate': 0.05,
                'num_leaves': 64,
                'min_child_samples': 50,
                'subsample': 0.8,
                'colsample_bytree': 0.8,
            })
            print("No saved params found; using defaults")

    model, metrics = train_final_model(
        X_train, y_train, X_test, y_test,
        params, sample_weights, feature_cols,
    )
    save_model(model, params, feature_cols, metrics)

    print("\n" + banner)
    print("✓ Training Complete!")
    print(banner + "\n")


if __name__ == "__main__":
    main()
|
scripts/06_evaluate.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 06: Model Evaluation
|
| 3 |
+
|
| 4 |
+
This script performs comprehensive evaluation of the trained model:
|
| 5 |
+
- Confusion matrix visualization
|
| 6 |
+
- Per-class metrics analysis
|
| 7 |
+
- Ordinal-specific metrics (linear weighted kappa)
|
| 8 |
+
- SHAP feature importance analysis
|
| 9 |
+
- Error analysis
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python scripts/06_evaluate.py
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import joblib
|
| 19 |
+
import lightgbm as lgb
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import seaborn as sns
|
| 24 |
+
import shap
|
| 25 |
+
from sklearn.metrics import (
|
| 26 |
+
accuracy_score,
|
| 27 |
+
balanced_accuracy_score,
|
| 28 |
+
classification_report,
|
| 29 |
+
cohen_kappa_score,
|
| 30 |
+
confusion_matrix,
|
| 31 |
+
f1_score,
|
| 32 |
+
precision_score,
|
| 33 |
+
recall_score,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Add project root to path
|
| 37 |
+
project_root = Path(__file__).parent.parent
|
| 38 |
+
sys.path.insert(0, str(project_root))
|
| 39 |
+
|
| 40 |
+
from config.config import (
|
| 41 |
+
TEST_PARQUET,
|
| 42 |
+
TRAIN_PARQUET,
|
| 43 |
+
MODELS_DIR,
|
| 44 |
+
FIGURES_DIR,
|
| 45 |
+
TARGET_COLUMN,
|
| 46 |
+
TARGET_CLASS_NAMES
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Set style
|
| 50 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_model_and_data() -> tuple:
    """Load the trained LightGBM booster, its metadata, and both data splits.

    Returns:
        Tuple of (booster, metadata dict, train DataFrame, test DataFrame).
    """
    print("Loading model and data...")

    # Trained booster was saved in LightGBM's native text format.
    booster_file = MODELS_DIR / 'wildfire_model.txt'
    model = lgb.Booster(model_file=str(booster_file))
    print(f" Model: {booster_file}")

    # Metadata carries the feature ordering used at training time.
    meta_file = MODELS_DIR / 'model_metadata.joblib'
    metadata = joblib.load(meta_file)
    print(f" Metadata: {meta_file}")

    # Both splits are needed: test for evaluation, train for reference.
    train_df = pd.read_parquet(TRAIN_PARQUET)
    test_df = pd.read_parquet(TEST_PARQUET)
    print(f" Test data: {len(test_df):,} rows")

    return model, metadata, train_df, test_df
| 74 |
+
|
| 75 |
+
def prepare_data(df: pd.DataFrame, feature_names: list) -> tuple:
    """Split a dataframe into a feature matrix and a target vector.

    Args:
        df: Dataframe containing both the features and the target column.
        feature_names: Ordered list of feature column names to select.

    Returns:
        Tuple (X, y) of NumPy arrays.
    """
    feature_matrix = df.loc[:, feature_names].values
    target = df[TARGET_COLUMN].values
    return feature_matrix, target
+
|
| 82 |
+
def compute_all_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> dict:
    """Compute overall, ordinal, and per-class evaluation metrics.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        y_proba: Predicted class probabilities. Accepted for interface
            symmetry with the caller; none of the metrics computed here
            consume it.

    Returns:
        Dict mapping metric names to scalar scores (per-class entries are
        arrays indexed by class).
    """
    metrics = {}

    # Standard (nominal) classification metrics.
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['balanced_accuracy'] = balanced_accuracy_score(y_true, y_pred)
    metrics['macro_f1'] = f1_score(y_true, y_pred, average='macro')
    metrics['weighted_f1'] = f1_score(y_true, y_pred, average='weighted')
    metrics['macro_precision'] = precision_score(y_true, y_pred, average='macro')
    metrics['macro_recall'] = recall_score(y_true, y_pred, average='macro')

    # Ordinal-aware agreement: weighted kappa penalizes predictions that
    # land farther from the true class.
    metrics['cohen_kappa_linear'] = cohen_kappa_score(y_true, y_pred, weights='linear')
    metrics['cohen_kappa_quadratic'] = cohen_kappa_score(y_true, y_pred, weights='quadratic')

    # Per-class breakdowns.
    metrics['per_class_precision'] = precision_score(y_true, y_pred, average=None)
    metrics['per_class_recall'] = recall_score(y_true, y_pred, average=None)
    metrics['per_class_f1'] = f1_score(y_true, y_pred, average=None)

    return metrics
+
|
| 108 |
+
def print_metrics(metrics: dict) -> None:
    """Pretty-print the metric dict produced by compute_all_metrics.

    Args:
        metrics: Dict with the overall, ordinal, and per-class entries
            emitted by compute_all_metrics.
    """
    print("\n" + "="*60)
    print("EVALUATION METRICS")
    print("="*60)

    print("\nOverall Metrics:")
    print(f" Accuracy: {metrics['accuracy']:.4f}")
    print(f" Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
    print(f" Macro F1: {metrics['macro_f1']:.4f}")
    print(f" Weighted F1: {metrics['weighted_f1']:.4f}")
    print(f" Macro Precision: {metrics['macro_precision']:.4f}")
    print(f" Macro Recall: {metrics['macro_recall']:.4f}")

    print("\nOrdinal Metrics (penalize distance from true class):")
    print(f" Cohen's Kappa (Linear): {metrics['cohen_kappa_linear']:.4f}")
    print(f" Cohen's Kappa (Quadratic): {metrics['cohen_kappa_quadratic']:.4f}")

    print("\nPer-Class Metrics:")
    print(f" {'Class':<10} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print(f" {'-'*40}")
    # Walk the per-class arrays in lockstep with the class names.
    per_class_rows = zip(TARGET_CLASS_NAMES,
                         metrics['per_class_precision'],
                         metrics['per_class_recall'],
                         metrics['per_class_f1'])
    for name, prec, rec, f1 in per_class_rows:
        print(f" {name:<10} {prec:>10.4f} {rec:>10.4f} {f1:>10.4f}")
+
|
| 134 |
+
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_path: Path) -> None:
    """Plot side-by-side raw-count and row-normalized confusion matrices.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        save_path: Destination path for the PNG figure.
    """
    print("\nGenerating confusion matrix...")

    cm = confusion_matrix(y_true, y_pred)
    # Row-normalize to per-actual-class rates. Guard rows that sum to zero
    # (a class absent from y_true) — the original 0/0 division produced
    # NaN rows and a runtime warning.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm_normalized = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Raw counts
    ax1 = axes[0]
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1,
                xticklabels=TARGET_CLASS_NAMES, yticklabels=TARGET_CLASS_NAMES)
    ax1.set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Predicted')
    ax1.set_ylabel('Actual')

    # Normalized (percentages)
    ax2 = axes[1]
    sns.heatmap(cm_normalized, annot=True, fmt='.1%', cmap='Blues', ax=ax2,
                xticklabels=TARGET_CLASS_NAMES, yticklabels=TARGET_CLASS_NAMES)
    ax2.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Predicted')
    ax2.set_ylabel('Actual')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {save_path}")
+
|
| 165 |
+
def plot_classification_report(y_true: np.ndarray, y_pred: np.ndarray, save_path: Path) -> None:
    """Plot classification metrics as bar chart.

    Renders per-class precision/recall/F1 as grouped bars with value labels
    and writes the figure to ``save_path``.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        save_path: Destination path for the PNG figure.
    """
    print("\nGenerating classification report plot...")

    # output_dict=True yields a nested dict keyed by class name plus the
    # aggregate rows 'accuracy', 'macro avg', 'weighted avg'.
    report = classification_report(y_true, y_pred, target_names=TARGET_CLASS_NAMES, output_dict=True)

    # Convert to DataFrame and drop the aggregate rows so only per-class rows
    # remain; row order follows TARGET_CLASS_NAMES, matching the x-tick labels.
    df_report = pd.DataFrame(report).T
    df_report = df_report.drop(['accuracy', 'macro avg', 'weighted avg'], errors='ignore')
    df_report = df_report[['precision', 'recall', 'f1-score']]

    fig, ax = plt.subplots(figsize=(10, 6))

    x = np.arange(len(TARGET_CLASS_NAMES))
    width = 0.25  # bar width: three bars fit side by side per class group

    bars1 = ax.bar(x - width, df_report['precision'], width, label='Precision', color='#3498db')
    bars2 = ax.bar(x, df_report['recall'], width, label='Recall', color='#2ecc71')
    bars3 = ax.bar(x + width, df_report['f1-score'], width, label='F1-Score', color='#e74c3c')

    ax.set_xlabel('Fire Size Class')
    ax.set_ylabel('Score')
    ax.set_title('Per-Class Classification Metrics', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(TARGET_CLASS_NAMES)
    ax.legend()
    # Headroom above 1.0 so the value labels drawn below stay inside the axes.
    ax.set_ylim(0, 1.1)

    # Add value labels just above each bar.
    for bars in [bars1, bars2, bars3]:
        for bar in bars:
            height = bar.get_height()
            ax.annotate(f'{height:.2f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {save_path}")
+
|
| 208 |
+
def plot_shap_importance(model: lgb.Booster, X: np.ndarray,
                         feature_names: list, save_path: Path,
                         max_display: int = 20) -> pd.DataFrame:
    """Generate SHAP feature importance plots.

    Computes TreeExplainer SHAP values on a random sample, aggregates mean
    absolute values per feature across classes, saves a bar chart (and, when
    the per-class values can be extracted, a summary plot for the Large
    class), and returns the aggregated importances.

    NOTE(review): the original annotation said ``-> None`` but the function
    returns ``importance_df`` (and main() consumes it); annotation corrected.

    Args:
        model: Trained LightGBM booster.
        X: Feature matrix aligned with ``feature_names``.
        feature_names: Column names for X.
        save_path: PNG destination for the bar chart; the summary plot is
            written next to it with a ``_summary`` suffix.
        max_display: How many top features to show.

    Returns:
        DataFrame with columns ``feature`` and ``importance``, sorted
        ascending by importance.
    """
    print("\nGenerating SHAP analysis...")
    print(f" X shape: {X.shape}")
    print(f" Number of feature names: {len(feature_names)}")

    # Use a sample for SHAP (faster computation); fixed seed so the sampled
    # rows — and hence the plots — are reproducible.
    sample_size = min(5000, len(X))
    np.random.seed(42)
    sample_idx = np.random.choice(len(X), sample_size, replace=False)
    X_sample = X[sample_idx]

    # Create explainer
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_sample)

    # SHAP values is a list of arrays (one per class for multiclass);
    # the exact return shape varies across shap versions, hence the
    # branching below. Average absolute SHAP values across all classes.
    if isinstance(shap_values, list):
        # List of per-class arrays, each (samples, features): mean absolute
        # value per feature across all samples and all classes.
        mean_shap = np.mean([np.abs(sv).mean(axis=0) for sv in shap_values], axis=0)
    else:
        # Single array: either (samples, features) or a flattened
        # (samples, features * classes).
        mean_shap = np.abs(shap_values).mean(axis=0)

        # If the size is a multiple of the feature count, it's likely a
        # multiclass result flattened along the feature axis.
        num_feats = len(feature_names)
        if mean_shap.size > num_feats and mean_shap.size % num_feats == 0:
            n_classes = mean_shap.size // num_feats
            print(f" Aggregating SHAP values for {n_classes} classes...")
            mean_shap = mean_shap.reshape(n_classes, num_feats).mean(axis=0)

    # Ensure mean_shap is 1D before pairing with feature names.
    if mean_shap.ndim > 1:
        mean_shap = mean_shap.flatten()

    print(f" Mean SHAP shape: {mean_shap.shape}")

    # Handle mismatch between feature_names and mean_shap length — trim both
    # to the shorter length rather than failing.
    if len(feature_names) != mean_shap.size:
        print(f" WARNING: Feature names ({len(feature_names)}) != SHAP values ({mean_shap.size})")
        # Trim to match
        n = min(len(feature_names), mean_shap.size)
        feature_names = feature_names[:n]
        mean_shap = mean_shap[:n]
        print(f" Trimmed to {n} features")

    # Create feature importance DataFrame (ascending so barh draws the most
    # important feature at the top).
    importance_df = pd.DataFrame({
        'feature': feature_names,
        'importance': mean_shap
    }).sort_values('importance', ascending=True)

    # Plot 1: Feature Importance Bar Chart
    plt.figure(figsize=(10, 8))
    top_features = importance_df.tail(max_display)
    plt.barh(top_features['feature'], top_features['importance'], color='steelblue')
    plt.xlabel('Mean |SHAP Value|')
    plt.title(f'Top {max_display} Feature Importance (SHAP)', fontsize=14, fontweight='bold')
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved importance plot: {save_path}")

    # Plot 2: SHAP Summary Plot (Large Fires)
    # Extract SHAP values for Large fire class (class index 2); the hardcoded
    # 3 below assumes a three-class problem (Small/Medium/Large).
    shap_values_large = None
    num_feats = len(feature_names)

    if isinstance(shap_values, list) and len(shap_values) > 2:
        # Already a list of arrays per class
        shap_values_large = shap_values[2]
    elif isinstance(shap_values, np.ndarray):
        # Single array - need to reshape if it's (samples, features * classes)
        if shap_values.shape[1] == num_feats * 3:
            # Reshape from (samples, features*classes) to (samples, classes, features)
            # Then extract class 2 (Large fires)
            reshaped = shap_values.reshape(shap_values.shape[0], 3, num_feats)
            shap_values_large = reshaped[:, 2, :]  # Class 2 = Large
            print(f" Extracted Large fire SHAP values: {shap_values_large.shape}")
        elif shap_values.shape[1] == num_feats:
            # Binary or single output - use as-is
            shap_values_large = shap_values

    if shap_values_large is not None:
        summary_path = save_path.parent / f"{save_path.stem}_summary{save_path.suffix}"
        plt.figure(figsize=(10, 8))
        try:
            print(" Generating SHAP summary plot...")
            shap.summary_plot(shap_values_large, X_sample, feature_names=feature_names,
                              max_display=max_display, show=False)
            plt.title('SHAP Summary: Large Fire Class', fontsize=14, fontweight='bold')
            plt.tight_layout()
            plt.savefig(summary_path, dpi=150, bbox_inches='tight')
            print(f" Saved summary plot: {summary_path}")
        except Exception as e:
            # Best-effort: a failed summary plot should not abort evaluation.
            print(f" Could not generate summary plot: {e}")
        plt.close()
    else:
        print(" Skipping summary plot (could not extract Large class SHAP values)")

    # Print top features
    print("\n Top 10 Most Important Features:")
    for _, row in importance_df.tail(10).iloc[::-1].iterrows():
        print(f" {row['feature']}: {row['importance']:.4f}")

    return importance_df
| 321 |
+
|
| 322 |
+
def analyze_errors(test_df: pd.DataFrame, y_true: np.ndarray,
                   y_pred: np.ndarray, save_path: Path) -> None:
    """Print a breakdown of misclassifications (error pairs, ordinal distance).

    Args:
        test_df: Test dataframe row-aligned with ``y_true``/``y_pred``.
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        save_path: Accepted for interface symmetry with the plotting helpers;
            nothing is currently written to it.
    """
    print("\nAnalyzing misclassifications...")

    # Work on a copy so the caller's dataframe is not mutated.
    test_df = test_df.copy()
    test_df['predicted'] = y_pred
    test_df['correct'] = y_true == y_pred

    errors = test_df[~test_df['correct']]

    print(f"\n Total errors: {len(errors):,} ({len(errors)/len(test_df)*100:.1f}%)")

    # Fix: with zero errors (a perfect run) every percentage computation
    # below divided by len(errors) and raised ZeroDivisionError.
    if len(errors) == 0:
        print("\n No misclassifications to analyze.")
        return

    # Error types: counts for every (actual, predicted) off-diagonal pair.
    print("\n Error Distribution:")
    for true_class in range(3):
        for pred_class in range(3):
            if true_class != pred_class:
                count = ((y_true == true_class) & (y_pred == pred_class)).sum()
                if count > 0:
                    pct = count / len(errors) * 100
                    true_name = TARGET_CLASS_NAMES[true_class]
                    pred_name = TARGET_CLASS_NAMES[pred_class]
                    print(f" {true_name} → {pred_name}: {count:,} ({pct:.1f}%)")

    # Adjacent vs non-adjacent errors (important for ordinal targets:
    # predicting one class away is less severe than two away).
    adjacent_errors = 0
    non_adjacent_errors = 0

    for true_class, pred_class in zip(y_true[y_true != y_pred], y_pred[y_true != y_pred]):
        if abs(true_class - pred_class) == 1:
            adjacent_errors += 1
        else:
            non_adjacent_errors += 1

    print(f"\n Ordinal Error Analysis:")
    print(f" Adjacent errors (off by 1): {adjacent_errors:,} ({adjacent_errors/len(errors)*100:.1f}%)")
    print(f" Non-adjacent errors (off by 2): {non_adjacent_errors:,} ({non_adjacent_errors/len(errors)*100:.1f}%)")
+
|
| 363 |
+
def plot_prediction_distribution(y_true: np.ndarray, y_pred: np.ndarray,
                                 y_proba: np.ndarray, save_path: Path) -> None:
    """Plot, per class, the predicted-probability histograms split by actual class.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Accepted for interface consistency with the other plotting
            helpers; not used here.
        y_proba: Predicted probabilities, one column per class.
        save_path: Destination path for the PNG figure.
    """
    print("\nGenerating prediction distribution plots...")

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for class_idx, class_name in enumerate(TARGET_CLASS_NAMES):
        panel = axes[class_idx]
        class_probs = y_proba[:, class_idx]

        # Overlay one histogram per actual class so separability is visible.
        for actual in range(3):
            selector = y_true == actual
            panel.hist(class_probs[selector], bins=50, alpha=0.5,
                       label=f'Actual: {TARGET_CLASS_NAMES[actual]}', density=True)

        panel.set_xlabel(f'P({class_name})')
        panel.set_ylabel('Density')
        panel.set_title(f'Predicted Probability: {class_name}', fontweight='bold')
        panel.legend(fontsize=8)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {save_path}")
| 390 |
+
|
| 391 |
+
def main():
    """Main evaluation pipeline.

    Loads the trained model and data, scores the test split, prints metrics,
    and writes all evaluation figures plus the SHAP importance CSV into
    FIGURES_DIR.
    """
    print("\n" + "="*60)
    print("MODEL EVALUATION")
    print("="*60)

    # Create figures directory (idempotent).
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

    # Load model and data; feature ordering comes from training metadata.
    model, metadata, train_df, test_df = load_model_and_data()
    feature_names = metadata['feature_names']

    # Prepare data. NOTE(review): X_train/y_train are built but not used by
    # any step below — candidate for removal, kept as-is here.
    X_test, y_test = prepare_data(test_df, feature_names)
    X_train, y_train = prepare_data(train_df, feature_names)

    # Make predictions: booster outputs per-class probabilities; argmax
    # picks the hard label.
    y_proba = model.predict(X_test)
    y_pred = np.argmax(y_proba, axis=1)

    # Compute metrics
    metrics = compute_all_metrics(y_test, y_pred, y_proba)
    print_metrics(metrics)

    # Generate plots
    plot_confusion_matrix(y_test, y_pred, FIGURES_DIR / 'confusion_matrix.png')
    plot_classification_report(y_test, y_pred, FIGURES_DIR / 'classification_metrics.png')
    plot_prediction_distribution(y_test, y_pred, y_proba, FIGURES_DIR / 'prediction_distribution.png')

    # SHAP analysis (returns the aggregated importance table).
    importance_df = plot_shap_importance(model, X_test, feature_names,
                                         FIGURES_DIR / 'shap_importance.png')

    # Error analysis
    analyze_errors(test_df, y_test, y_pred, FIGURES_DIR / 'error_analysis.png')

    # Save importance rankings
    importance_df.to_csv(FIGURES_DIR / 'feature_importance.csv', index=False)

    print("\n" + "="*60)
    print("✓ Evaluation Complete!")
    print(f" Figures saved to: {FIGURES_DIR}")
    print("="*60 + "\n")


if __name__ == "__main__":
    main()
|
scripts/07_predict.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script 07: Prediction Pipeline
|
| 3 |
+
|
| 4 |
+
This script provides inference capabilities:
|
| 5 |
+
- Load trained model
|
| 6 |
+
- Preprocess new data
|
| 7 |
+
- Generate predictions with probabilities
|
| 8 |
+
- Can be used as a module or standalone script
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
# Single prediction
|
| 12 |
+
python scripts/07_predict.py --lat 34.05 --lon -118.24 --state CA --cause "Debris Burning" --month 7
|
| 13 |
+
|
| 14 |
+
# Batch prediction from CSV
|
| 15 |
+
python scripts/07_predict.py --input new_fires.csv --output predictions.csv
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
import joblib
|
| 24 |
+
import lightgbm as lgb
|
| 25 |
+
import numpy as np
|
| 26 |
+
import pandas as pd
|
| 27 |
+
|
| 28 |
+
# Add project root to path
|
| 29 |
+
project_root = Path(__file__).parent.parent
|
| 30 |
+
sys.path.insert(0, str(project_root))
|
| 31 |
+
|
| 32 |
+
from config.config import (
|
| 33 |
+
MODELS_DIR,
|
| 34 |
+
TARGET_CLASS_NAMES,
|
| 35 |
+
FIRE_SIZE_CLASS_MAPPING,
|
| 36 |
+
CATEGORICAL_FEATURES,
|
| 37 |
+
N_GEO_CLUSTERS,
|
| 38 |
+
LAT_BINS,
|
| 39 |
+
LON_BINS
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class WildfirePredictor:
|
| 44 |
+
"""Wildfire size class predictor."""
|
| 45 |
+
|
| 46 |
+
def __init__(self, model_dir: Path = MODELS_DIR):
|
| 47 |
+
"""Initialize predictor with trained model."""
|
| 48 |
+
self.model_dir = model_dir
|
| 49 |
+
self.model = None
|
| 50 |
+
self.metadata = None
|
| 51 |
+
self.feature_names = None
|
| 52 |
+
self.encoders = {}
|
| 53 |
+
|
| 54 |
+
self._load_model()
|
| 55 |
+
|
| 56 |
+
def _load_model(self) -> None:
|
| 57 |
+
"""Load trained model and metadata."""
|
| 58 |
+
model_path = self.model_dir / 'wildfire_model.txt'
|
| 59 |
+
metadata_path = self.model_dir / 'model_metadata.joblib'
|
| 60 |
+
|
| 61 |
+
if not model_path.exists():
|
| 62 |
+
raise FileNotFoundError(f"Model not found at {model_path}. Run training first.")
|
| 63 |
+
|
| 64 |
+
self.model = lgb.Booster(model_file=str(model_path))
|
| 65 |
+
self.metadata = joblib.load(metadata_path)
|
| 66 |
+
self.feature_names = self.metadata['feature_names']
|
| 67 |
+
|
| 68 |
+
print(f"Loaded model with {len(self.feature_names)} features")
|
| 69 |
+
|
| 70 |
+
def _create_features(self, df: pd.DataFrame) -> pd.DataFrame:
|
| 71 |
+
"""Create features for prediction."""
|
| 72 |
+
df = df.copy()
|
| 73 |
+
|
| 74 |
+
# Ensure required columns exist
|
| 75 |
+
required = ['LATITUDE', 'LONGITUDE', 'FIRE_YEAR', 'DISCOVERY_DOY']
|
| 76 |
+
for col in required:
|
| 77 |
+
if col not in df.columns:
|
| 78 |
+
raise ValueError(f"Missing required column: {col}")
|
| 79 |
+
|
| 80 |
+
# Temporal features
|
| 81 |
+
reference_year = 2001
|
| 82 |
+
df['temp_date'] = pd.to_datetime(
|
| 83 |
+
df['DISCOVERY_DOY'].astype(int).astype(str) + f'-{reference_year}',
|
| 84 |
+
format='%j-%Y',
|
| 85 |
+
errors='coerce'
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
df['month'] = df['temp_date'].dt.month
|
| 89 |
+
df['day_of_week'] = df['temp_date'].dt.dayofweek
|
| 90 |
+
df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
|
| 91 |
+
df['season'] = df['month'].apply(lambda m:
|
| 92 |
+
1 if m in [12, 1, 2] else
|
| 93 |
+
2 if m in [3, 4, 5] else
|
| 94 |
+
3 if m in [6, 7, 8] else 4
|
| 95 |
+
)
|
| 96 |
+
df['is_fire_season'] = df['month'].isin([6, 7, 8, 9, 10]).astype(int)
|
| 97 |
+
|
| 98 |
+
# Cyclical features
|
| 99 |
+
df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
|
| 100 |
+
df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
|
| 101 |
+
df['doy_sin'] = np.sin(2 * np.pi * df['DISCOVERY_DOY'] / 365)
|
| 102 |
+
df['doy_cos'] = np.cos(2 * np.pi * df['DISCOVERY_DOY'] / 365)
|
| 103 |
+
df['dow_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
|
| 104 |
+
df['dow_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
|
| 105 |
+
|
| 106 |
+
# Year features
|
| 107 |
+
min_year, max_year = 1992, 2015
|
| 108 |
+
df['year_normalized'] = (df['FIRE_YEAR'] - min_year) / (max_year - min_year)
|
| 109 |
+
df['years_since_1992'] = df['FIRE_YEAR'] - min_year
|
| 110 |
+
|
| 111 |
+
# Geospatial features
|
| 112 |
+
lat_min, lat_max = 24.0, 50.0
|
| 113 |
+
lon_min, lon_max = -125.0, -66.0
|
| 114 |
+
lat_edges = np.linspace(lat_min, lat_max, LAT_BINS + 1)
|
| 115 |
+
lon_edges = np.linspace(lon_min, lon_max, LON_BINS + 1)
|
| 116 |
+
|
| 117 |
+
df['lat_bin'] = pd.cut(df['LATITUDE'], bins=lat_edges, labels=False, include_lowest=True)
|
| 118 |
+
df['lon_bin'] = pd.cut(df['LONGITUDE'], bins=lon_edges, labels=False, include_lowest=True)
|
| 119 |
+
df['lat_bin'] = df['lat_bin'].fillna(5).astype(int)
|
| 120 |
+
df['lon_bin'] = df['lon_bin'].fillna(5).astype(int)
|
| 121 |
+
|
| 122 |
+
# Coordinate features
|
| 123 |
+
df['lat_squared'] = df['LATITUDE'] ** 2
|
| 124 |
+
df['lon_squared'] = df['LONGITUDE'] ** 2
|
| 125 |
+
df['lat_lon_interaction'] = df['LATITUDE'] * df['LONGITUDE']
|
| 126 |
+
|
| 127 |
+
center_lat, center_lon = 39.8, -98.6
|
| 128 |
+
df['dist_from_center'] = np.sqrt(
|
| 129 |
+
(df['LATITUDE'] - center_lat) ** 2 +
|
| 130 |
+
(df['LONGITUDE'] - center_lon) ** 2
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Placeholder for geo_cluster (would need kmeans model)
|
| 134 |
+
df['geo_cluster'] = 0
|
| 135 |
+
|
| 136 |
+
# Drop temporary columns
|
| 137 |
+
df = df.drop(columns=['temp_date'], errors='ignore')
|
| 138 |
+
|
| 139 |
+
return df
|
| 140 |
+
|
| 141 |
+
def _encode_categoricals(self, df: pd.DataFrame) -> pd.DataFrame:
|
| 142 |
+
"""Encode categorical variables."""
|
| 143 |
+
df = df.copy()
|
| 144 |
+
|
| 145 |
+
# Simple label encoding for inference
|
| 146 |
+
# In production, would need to use same encoders as training
|
| 147 |
+
for col in CATEGORICAL_FEATURES:
|
| 148 |
+
encoded_col = f'{col}_encoded'
|
| 149 |
+
if col in df.columns:
|
| 150 |
+
# Simple hash-based encoding as fallback
|
| 151 |
+
df[encoded_col] = df[col].astype(str).apply(lambda x: hash(x) % 100)
|
| 152 |
+
else:
|
| 153 |
+
df[encoded_col] = 0
|
| 154 |
+
|
| 155 |
+
return df
|
| 156 |
+
|
| 157 |
+
def preprocess(self, df: pd.DataFrame) -> np.ndarray:
|
| 158 |
+
"""Preprocess data for prediction."""
|
| 159 |
+
df = self._create_features(df)
|
| 160 |
+
df = self._encode_categoricals(df)
|
| 161 |
+
|
| 162 |
+
# Select and order features to match training
|
| 163 |
+
missing_features = [f for f in self.feature_names if f not in df.columns]
|
| 164 |
+
if missing_features:
|
| 165 |
+
print(f"Warning: Missing features (filled with 0): {missing_features}")
|
| 166 |
+
for f in missing_features:
|
| 167 |
+
df[f] = 0
|
| 168 |
+
|
| 169 |
+
X = df[self.feature_names].values
|
| 170 |
+
return X
|
| 171 |
+
|
| 172 |
+
def predict(self, df: pd.DataFrame) -> pd.DataFrame:
|
| 173 |
+
"""Generate predictions for input data."""
|
| 174 |
+
X = self.preprocess(df)
|
| 175 |
+
|
| 176 |
+
# Get probabilities
|
| 177 |
+
proba = self.model.predict(X)
|
| 178 |
+
pred_class = np.argmax(proba, axis=1)
|
| 179 |
+
|
| 180 |
+
# Create results dataframe
|
| 181 |
+
results = df.copy()
|
| 182 |
+
results['predicted_class'] = pred_class
|
| 183 |
+
results['predicted_label'] = [TARGET_CLASS_NAMES[c] for c in pred_class]
|
| 184 |
+
results['prob_small'] = proba[:, 0]
|
| 185 |
+
results['prob_medium'] = proba[:, 1]
|
| 186 |
+
results['prob_large'] = proba[:, 2]
|
| 187 |
+
results['confidence'] = np.max(proba, axis=1)
|
| 188 |
+
|
| 189 |
+
return results
|
| 190 |
+
|
| 191 |
+
def predict_single(self, latitude: float, longitude: float,
                   fire_year: int, discovery_doy: int,
                   state: str = 'Unknown',
                   cause: str = 'Unknown',
                   agency: str = 'Unknown',
                   owner: str = 'Unknown') -> dict:
    """Predict the size class of one fire event.

    Wraps the arguments into a single-row frame, delegates to
    :meth:`predict`, and unpacks the result into plain Python types.

    Returns:
        dict with ``predicted_class`` (int), ``predicted_label`` (str),
        ``probabilities`` (dict of Small/Medium/Large floats) and
        ``confidence`` (float, winning-class probability).
    """
    record = {
        'LATITUDE': latitude,
        'LONGITUDE': longitude,
        'FIRE_YEAR': fire_year,
        'DISCOVERY_DOY': discovery_doy,
        'STATE': state,
        'STAT_CAUSE_DESCR': cause,
        'NWCG_REPORTING_AGENCY': agency,
        'OWNER_DESCR': owner,
    }
    row = self.predict(pd.DataFrame([record])).iloc[0]

    return {
        'predicted_class': int(row['predicted_class']),
        'predicted_label': row['predicted_label'],
        'probabilities': {
            'Small': float(row['prob_small']),
            'Medium': float(row['prob_medium']),
            'Large': float(row['prob_large']),
        },
        'confidence': float(row['confidence']),
    }
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def main():
    """Main prediction script.

    Three modes, chosen from the CLI flags:
      * ``--input CSV``  -> batch prediction, saved to ``--output``
        (default ``predictions.csv``)
      * ``--lat/--lon``  -> single-event prediction, pretty-printed
      * neither          -> canned demo prediction plus usage help
    """
    parser = argparse.ArgumentParser(description='Wildfire size prediction')

    # Single prediction arguments
    parser.add_argument('--lat', type=float, help='Latitude')
    parser.add_argument('--lon', type=float, help='Longitude')
    parser.add_argument('--year', type=int, default=2015, help='Fire year')
    parser.add_argument('--doy', type=int, default=200, help='Day of year')
    parser.add_argument('--state', type=str, default='Unknown', help='State code')
    parser.add_argument('--cause', type=str, default='Unknown', help='Fire cause')

    # Batch prediction arguments
    parser.add_argument('--input', type=str, help='Input CSV file for batch prediction')
    parser.add_argument('--output', type=str, help='Output CSV file for predictions')

    args = parser.parse_args()

    # Initialize predictor
    # NOTE(review): WildfirePredictor is defined earlier in this file;
    # presumably its constructor loads the trained model — confirm.
    predictor = WildfirePredictor()

    if args.input:
        # Batch prediction
        print(f"\nProcessing batch predictions from: {args.input}")
        df = pd.read_csv(args.input)
        results = predictor.predict(df)

        # Default output path when --output is omitted.
        output_path = args.output or 'predictions.csv'
        results.to_csv(output_path, index=False)
        print(f"Predictions saved to: {output_path}")

    elif args.lat is not None and args.lon is not None:
        # Single prediction
        print("\n" + "="*60)
        print("SINGLE FIRE PREDICTION")
        print("="*60)

        result = predictor.predict_single(
            latitude=args.lat,
            longitude=args.lon,
            fire_year=args.year,
            discovery_doy=args.doy,
            state=args.state,
            cause=args.cause
        )

        print(f"\nInput:")
        print(f" Location: ({args.lat}, {args.lon})")
        print(f" Year: {args.year}, Day of Year: {args.doy}")
        print(f" State: {args.state}, Cause: {args.cause}")

        print(f"\nPrediction:")
        print(f" Class: {result['predicted_class']} ({result['predicted_label']})")
        print(f" Confidence: {result['confidence']:.1%}")

        print(f"\nProbabilities:")
        for label, prob in result['probabilities'].items():
            # ASCII-art bar: 20 chars at probability 1.0.
            bar = '█' * int(prob * 20)
            print(f" {label:>6}: {prob:>6.1%} {bar}")

    else:
        # Demo prediction
        print("\n" + "="*60)
        print("DEMO PREDICTION")
        print("="*60)

        # Example: Summer fire in California
        result = predictor.predict_single(
            latitude=34.05,
            longitude=-118.24,
            fire_year=2015,
            discovery_doy=200,  # Mid-July
            state='CA',
            cause='Debris Burning'
        )

        print("\nExample: Summer fire in Los Angeles area")
        print(f" Predicted: {result['predicted_label']} (confidence: {result['confidence']:.1%})")
        print(f" Probabilities: Small={result['probabilities']['Small']:.1%}, "
              f"Medium={result['probabilities']['Medium']:.1%}, "
              f"Large={result['probabilities']['Large']:.1%}")

        print("\nUsage:")
        print(" Single: python 07_predict.py --lat 34.05 --lon -118.24 --state CA --cause 'Lightning'")
        print(" Batch: python 07_predict.py --input fires.csv --output predictions.csv")


if __name__ == "__main__":
    main()
|