zakaneki committed on
Commit 64e892b · verified · 1 Parent(s): bdcfc4e

first commit
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ FPA_FOD_20170508.sqlite filter=lfs diff=lfs merge=lfs -text
+ models/wildfire_model.txt filter=lfs diff=lfs merge=lfs -text
+ reports/figures/geographic_distribution.png filter=lfs diff=lfs merge=lfs -text
+ reports/figures/shap_importance_summary.png filter=lfs diff=lfs merge=lfs -text
+ reports/figures/temporal_patterns.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,47 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Virtual environments
+ venv/
+ env/
+ .venv/
+ .env/
+
+ # IDE settings
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Jupyter Notebooks
+ .ipynb_checkpoints/
+
+ # Data files
+ *.sqlite
+ *.db
+ *.parquet
+ data/processed/
+ data/raw/
+
+ # Model artifacts
+ models/*.txt
+ models/*.joblib
+ models/*.json
+ models/*.pkl
+
+ # Reports and figures
+ reports/figures/*.png
+ reports/figures/*.csv
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ # Environment variables
+ .env
+ .env.local
FPA_FOD_20170508.sqlite ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f04b23a24989770ce05fa354662b03e597ad164ddf5b7932b8a53b46d0ed428b
+ size 795785216
README.md CHANGED
@@ -1,3 +1,156 @@
- ---
- license: apache-2.0
- ---
+ # Wildfire Size Classification Project
+
+ Predicting wildfire size classes using machine learning on the FPA FOD (Fire Program Analysis Fire-Occurrence Database), which contains 1.88 million US wildfire records from 1992-2015.
+
+ ## Project Overview
+
+ This project builds an **ordinal classification model** to predict fire size categories:
+ - **Small** (0-9.9 acres): Original classes A + B
+ - **Medium** (10-299 acres): Original classes C + D
+ - **Large** (300+ acres): Original classes E + F + G
+
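The grouping above is a plain letter-to-ordinal mapping; a minimal sketch, mirroring `FIRE_SIZE_CLASS_MAPPING` in `config/config.py` (the `to_ordinal` helper is illustrative, not a function from the repo):

```python
# Ordinal target: 0 = Small (A, B), 1 = Medium (C, D), 2 = Large (E, F, G)
FIRE_SIZE_CLASS_MAPPING = {
    'A': 0, 'B': 0,
    'C': 1, 'D': 1,
    'E': 2, 'F': 2, 'G': 2,
}

def to_ordinal(size_class: str) -> int:
    """Map an FPA FOD fire size class letter to the 3-class ordinal target."""
    return FIRE_SIZE_CLASS_MAPPING[size_class.upper()]
```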
+ ### Key Features
+ - **Ordinal-aware classification**: Leverages the natural ordering of fire size classes
+ - **Geospatial features**: Coordinate clustering, regional binning, distance metrics
+ - **Temporal features**: Cyclical encoding of month/day, fire season indicators
+ - **Class imbalance handling**: Balanced class weights for rare large fire events
+ - **Interpretable results**: SHAP feature importance analysis
+
+ ## Project Structure
+
+ ```
+ wildfires/
+ ├── config/
+ │   ├── __init__.py                # Package init
+ │   └── config.py                  # Configuration settings
+ ├── data/
+ │   └── processed/                 # Processed parquet files (train/test splits)
+ ├── models/                        # Saved model artifacts
+ │   ├── best_params.json           # Tuned hyperparameters
+ │   ├── model_metadata.joblib      # Feature names and metrics
+ │   └── wildfire_model.txt         # Trained LightGBM model
+ ├── reports/
+ │   └── figures/                   # Visualizations and metrics
+ ├── scripts/
+ │   ├── 01_extract_data.py         # Extract SQLite → Parquet
+ │   ├── 02_eda.py                  # Exploratory data analysis
+ │   ├── 03_preprocess.py           # Data preprocessing
+ │   ├── 04_feature_engineering.py  # Feature creation
+ │   ├── 05_train_model.py          # Model training
+ │   ├── 06_evaluate.py             # Model evaluation
+ │   └── 07_predict.py              # Prediction pipeline
+ ├── run_pipeline.py                # Run full or partial pipeline
+ ├── requirements.txt               # Dependencies
+ ├── .gitignore                     # Git ignore rules
+ └── README.md
+ ```
+
+ ## Getting Started
+
+ ### Prerequisites
+ - Python 3.9+
+ - SQLite database file (`FPA_FOD_20170508.sqlite`)
+
+ ### Installation
+
+ 1. Clone/download the repository
+ 2. Create a virtual environment:
+ ```bash
+ python -m venv venv
+ venv\Scripts\activate      # Windows
+ # source venv/bin/activate # Linux/Mac
+ ```
+ 3. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 4. Place the SQLite database file in the project root
+
+ ### Running the Pipeline
+
+ **Using the pipeline runner (recommended):**
+
+ ```bash
+ # Run full pipeline
+ python run_pipeline.py
+
+ # Skip EDA step
+ python run_pipeline.py --skip-eda
+
+ # Run with hyperparameter tuning
+ python run_pipeline.py --tune
+
+ # Resume from a specific step (1-6)
+ python run_pipeline.py --from-step 5
+ ```
+
+ **Or execute scripts individually:**
+
+ ```bash
+ # 1. Extract data from SQLite
+ python scripts/01_extract_data.py
+
+ # 2. Exploratory data analysis (generates plots)
+ python scripts/02_eda.py
+
+ # 3. Preprocess data
+ python scripts/03_preprocess.py
+
+ # 4. Feature engineering
+ python scripts/04_feature_engineering.py
+
+ # 5. Train model (add --tune for hyperparameter tuning)
+ python scripts/05_train_model.py
+ # python scripts/05_train_model.py --tune  # With Optuna tuning
+
+ # 6. Evaluate model
+ python scripts/06_evaluate.py
+
+ # 7. Make predictions
+ python scripts/07_predict.py --lat 34.05 --lon -118.24 --state CA --cause "Lightning"
+ ```
+
+ ## Model Details
+
+ ### Features Used
+ - **Temporal**: Month, day of week, season, fire season indicator (cyclically encoded)
+ - **Geospatial**: Lat/lon coordinates, regional clusters (K-means), coordinate bins
+ - **Categorical**: State, fire cause, reporting agency, land owner
+ - **Year**: Fire year, years since 1992
+
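The cyclical encoding mentioned above maps a periodic feature onto a circle so that December sits next to January; a minimal sketch (the feature names `month_sin`/`month_cos` match those listed in `reports/figures/feature_importance.csv`, but the helper itself is illustrative):

```python
import numpy as np

def encode_cyclical(values: np.ndarray, period: int):
    """Encode a periodic feature as (sin, cos) so its endpoints are adjacent."""
    radians = 2 * np.pi * values / period
    return np.sin(radians), np.cos(radians)

months = np.array([1, 6, 12])
month_sin, month_cos = encode_cyclical(months, 12)
# Month 12 lands at the same point as month 0, next to month 1 on the circle.
```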
+ ### Algorithm
+ - **LightGBM** gradient boosting for multi-class classification
+ - Class weights to handle imbalanced data (~90% small fires)
+ - Linear weighted Cohen's Kappa for ordinal evaluation
+
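Balanced class weights give each class weight `n_samples / (n_classes * count)`, so rare large fires count more per sample; a sketch with scikit-learn's `compute_class_weight` (the exact call inside `05_train_model.py` is an assumption):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy label vector skewed like the real data (~90% small fires)
y = np.array([0] * 90 + [1] * 8 + [2] * 2)

weights = compute_class_weight(class_weight="balanced",
                               classes=np.array([0, 1, 2]), y=y)
weight_map = dict(zip([0, 1, 2], weights))  # rare Large class gets the largest weight
```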
+ ### Expected Performance
+ - Balanced Accuracy: ~65-75%
+ - Macro F1 Score: ~0.45-0.55
+ - Large fire detection is challenging due to class imbalance
+
+ ## Evaluation Metrics
+
+ For ordinal classification, we prioritize:
+ - **Macro F1**: Equal importance to all classes
+ - **Balanced Accuracy**: Accounts for class imbalance
+ - **Linear Weighted Kappa**: Penalizes predictions far from the true class
+
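Linear weighted kappa is available in scikit-learn; with linear weights, a two-class miss (Small predicted as Large) is penalized twice as heavily as an adjacent-class miss, which is what makes it suitable for ordinal targets. A small sketch (toy labels, not project data):

```python
from sklearn.metrics import cohen_kappa_score

y_true = [0, 0, 1, 1, 2, 2]
y_near = [0, 0, 1, 2, 2, 1]  # the two mistakes are one class off
y_far  = [2, 0, 1, 1, 2, 0]  # the two mistakes are two classes off

kappa_near = cohen_kappa_score(y_true, y_near, weights="linear")
kappa_far = cohen_kappa_score(y_true, y_far, weights="linear")
# Same number of errors, but kappa_near > kappa_far because the
# near-miss errors are closer to the true class.
```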
+ ## Output Files
+
+ After running the pipeline:
+ - `data/processed/`: Parquet files for train/test splits
+ - `models/wildfire_model.txt`: Trained LightGBM model
+ - `models/model_metadata.joblib`: Feature names and metrics
+ - `reports/figures/`: Visualizations (confusion matrix, SHAP plots, etc.)
+
+ ## Data Source
+
+ **Fire Program Analysis Fire-Occurrence Database (FPA FOD)**
+ - 1.88 million geo-referenced wildfire records
+ - Period: 1992-2015
+ - 140 million acres burned
+ - Source: US federal, state, and local fire organizations
+
+ ## License
+
+ This project uses publicly available government data.
config/__init__.py ADDED
@@ -0,0 +1 @@
+ from .config import *
config/config.py ADDED
@@ -0,0 +1,181 @@
+ """
+ Configuration settings for the Wildfire Size Classification project.
+
+ Target: Ordinal classification of fire size into 3 classes:
+ - 0 (Small):  Classes A + B (0 - 9.9 acres)
+ - 1 (Medium): Classes C + D (10 - 299 acres)
+ - 2 (Large):  Classes E + F + G (300+ acres)
+ """
+
+ from pathlib import Path
+
+ # =============================================================================
+ # PATHS
+ # =============================================================================
+
+ # Project root
+ PROJECT_ROOT = Path(__file__).parent.parent
+
+ # Data paths
+ DATA_DIR = PROJECT_ROOT / "data"
+ RAW_DATA_DIR = DATA_DIR / "raw"
+ PROCESSED_DATA_DIR = DATA_DIR / "processed"
+
+ # SQLite database path (adjust filename if different)
+ SQLITE_DB_PATH = PROJECT_ROOT / "FPA_FOD_20170508.sqlite"
+
+ # Output paths
+ MODELS_DIR = PROJECT_ROOT / "models"
+ REPORTS_DIR = PROJECT_ROOT / "reports"
+ FIGURES_DIR = REPORTS_DIR / "figures"
+
+ # Processed data files
+ RAW_PARQUET = PROCESSED_DATA_DIR / "fires_raw.parquet"
+ PROCESSED_PARQUET = PROCESSED_DATA_DIR / "fires_processed.parquet"
+ FEATURES_PARQUET = PROCESSED_DATA_DIR / "fires_features.parquet"
+ TRAIN_PARQUET = PROCESSED_DATA_DIR / "train.parquet"
+ TEST_PARQUET = PROCESSED_DATA_DIR / "test.parquet"
+
+ # =============================================================================
+ # TARGET VARIABLE CONFIGURATION
+ # =============================================================================
+
+ # Original fire size classes mapping to ordinal target
+ # A (0-0.25 acres), B (0.26-9.9 acres)                    -> 0 (Small)
+ # C (10-99.9 acres), D (100-299 acres)                    -> 1 (Medium)
+ # E (300-999 acres), F (1000-4999 acres), G (5000+ acres) -> 2 (Large)
+
+ FIRE_SIZE_CLASS_MAPPING = {
+     'A': 0, 'B': 0,          # Small
+     'C': 1, 'D': 1,          # Medium
+     'E': 2, 'F': 2, 'G': 2   # Large
+ }
+
+ TARGET_CLASS_NAMES = ['Small', 'Medium', 'Large']
+ TARGET_COLUMN = 'fire_size_ordinal'
+ ORIGINAL_TARGET_COLUMN = 'FIRE_SIZE_CLASS'
+
+ # =============================================================================
+ # FEATURE CONFIGURATION
+ # =============================================================================
+
+ # Columns to drop (IDs, redundant info, text fields)
+ COLUMNS_TO_DROP = [
+     'FOD_ID', 'FPA_ID', 'SOURCE_SYSTEM_TYPE', 'SOURCE_SYSTEM',
+     'NWCG_REPORTING_UNIT_ID', 'NWCG_REPORTING_UNIT_NAME',
+     'SOURCE_REPORTING_UNIT', 'SOURCE_REPORTING_UNIT_NAME',
+     'LOCAL_FIRE_REPORT_ID', 'LOCAL_INCIDENT_ID',
+     'FIRE_CODE', 'FIRE_NAME',
+     'ICS_209_INCIDENT_NUMBER', 'ICS_209_NAME',
+     'MTBS_ID', 'MTBS_FIRE_NAME', 'COMPLEX_NAME',
+     'DISCOVERY_DATE', 'DISCOVERY_TIME',
+     'CONT_DATE', 'CONT_DOY', 'CONT_TIME',
+     'FIPS_CODE', 'FIPS_NAME',
+     'FIRE_SIZE',        # Don't use actual size as a feature - it's what we're predicting
+     'FIRE_SIZE_CLASS',  # Original target
+     'Shape'             # Geometry column if present
+ ]
+
+ # Categorical features to encode
+ CATEGORICAL_FEATURES = [
+     'NWCG_REPORTING_AGENCY',
+     'STAT_CAUSE_DESCR',
+     'STATE',
+     'OWNER_DESCR'
+ ]
+
+ # Numerical features (after feature engineering)
+ NUMERICAL_FEATURES = [
+     'LATITUDE',
+     'LONGITUDE',
+     'DISCOVERY_DOY',
+     'FIRE_YEAR'
+ ]
+
+ # Temporal features to create
+ TEMPORAL_FEATURES = [
+     'month',
+     'season',
+     'day_of_week',
+     'is_weekend'
+ ]
+
+ # Geospatial features to create
+ GEOSPATIAL_FEATURES = [
+     'lat_bin',
+     'lon_bin',
+     'geo_cluster',
+     'lat_lon_interaction'
+ ]
+
+ # =============================================================================
+ # MODEL CONFIGURATION
+ # =============================================================================
+
+ # Random seed for reproducibility
+ RANDOM_STATE = 42
+
+ # Train/test split ratio
+ TEST_SIZE = 0.2
+
+ # Cross-validation folds
+ N_FOLDS = 5
+
+ # Class weights for imbalanced data (will be computed dynamically)
+ USE_CLASS_WEIGHTS = True
+
+ # LightGBM base parameters for ordinal classification
+ LIGHTGBM_PARAMS = {
+     'objective': 'multiclass',
+     'num_class': 3,
+     'metric': 'multi_logloss',
+     'boosting_type': 'gbdt',
+     'verbosity': -1,
+     'random_state': RANDOM_STATE,
+     'n_jobs': -1
+ }
+
+ # Optuna hyperparameter search space
+ OPTUNA_SEARCH_SPACE = {
+     'n_estimators': (100, 1000),
+     'max_depth': (3, 12),
+     'learning_rate': (0.01, 0.3),
+     'num_leaves': (20, 150),
+     'min_child_samples': (10, 100),
+     'subsample': (0.6, 1.0),
+     'colsample_bytree': (0.6, 1.0),
+     'reg_alpha': (0.0, 1.0),
+     'reg_lambda': (0.0, 1.0)
+ }
+
+ # Number of Optuna trials
+ N_OPTUNA_TRIALS = 50
+
+ # =============================================================================
+ # GEOSPATIAL CLUSTERING CONFIGURATION
+ # =============================================================================
+
+ # Number of clusters for geographic regions
+ N_GEO_CLUSTERS = 20
+
+ # Latitude/Longitude binning
+ LAT_BINS = 10
+ LON_BINS = 10
+
+ # =============================================================================
+ # EVALUATION METRICS
+ # =============================================================================
+
+ # Primary metric for model selection
+ PRIMARY_METRIC = 'macro_f1'
+
+ # All metrics to compute
+ EVALUATION_METRICS = [
+     'accuracy',
+     'balanced_accuracy',
+     'macro_f1',
+     'weighted_f1',
+     'cohen_kappa',
+     'macro_precision',
+     'macro_recall'
+ ]
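`LAT_BINS`/`LON_BINS` drive the coordinate binning behind the `lat_bin`/`lon_bin` features. A minimal sketch of how such bins could be derived; the 24-50°N range (matching the EDA plot limits) and the `np.digitize` call are assumptions, not code from the repo:

```python
import numpy as np

LAT_BINS = 10  # from config.config

# Assumed latitude range for the contiguous US, split into LAT_BINS equal bins
edges = np.linspace(24, 50, LAT_BINS + 1)

lat = np.array([25.0, 34.05, 48.9])
# digitize against the interior edges gives bin indices 0..LAT_BINS-1
lat_bin = np.digitize(lat, edges[1:-1])
```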
data/processed/fires_features.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6926ded384d61d2cd531853ce79785befc12cbb9d827ebf89e5f334aefef5c57
+ size 116607705
data/processed/fires_processed.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:245d6fccc065f5981420244d0aa7a6d0cd12cbe9f89753852b12eba03978c75f
+ size 26619137
data/processed/fires_raw.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:841e692f7f4044513d96c762f726a276e165a009a7f6ecb8d05547d8cee59cf0
+ size 127864657
data/processed/test.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2c42ac67caff94f6882dc3f7f00b3e5a30f37d19672d4fa5bd279795622c7df
+ size 23840600
data/processed/train.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb2678f4b14baf3863f32818b4030253ac731587c1f3451579dc38c1dad691d2
+ size 93657151
models/best_params.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "n_estimators": 516,
+   "max_depth": 9,
+   "learning_rate": 0.21074359142553023,
+   "num_leaves": 131,
+   "min_child_samples": 70,
+   "subsample": 0.7796691167092181,
+   "colsample_bytree": 0.7018466944576152,
+   "reg_alpha": 0.52062185793523,
+   "reg_lambda": 0.3425483694161869
+ }
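These tuned values would typically be layered over the base `LIGHTGBM_PARAMS` from `config/config.py` before training; how `05_train_model.py` actually combines them is an assumption, but the usual dict-merge pattern looks like this:

```python
# Base parameters, copied from config.config for self-containment
LIGHTGBM_PARAMS = {
    "objective": "multiclass",
    "num_class": 3,
    "metric": "multi_logloss",
    "boosting_type": "gbdt",
    "verbosity": -1,
    "random_state": 42,
    "n_jobs": -1,
}

# Subset of models/best_params.json (illustration)
best_params = {"n_estimators": 516, "max_depth": 9, "num_leaves": 131}

# Later keys win, so tuned values extend the base without losing fixed settings
params = {**LIGHTGBM_PARAMS, **best_params}
```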
models/model_metadata.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:37b351a57637f9c738d07be623d8cfbfe9a4a996f1e958c9e63883e023527925
+ size 958
models/wildfire_model.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b1a077debe29e736c3382e934aa8b6e5c7cf235e8707d680bb8664ab5317d99
+ size 21722535
reports/figures/cause_by_size.png ADDED
reports/figures/class_distribution.png ADDED
reports/figures/classification_metrics.png ADDED
reports/figures/confusion_matrix.png ADDED
reports/figures/feature_importance.csv ADDED
@@ -0,0 +1,29 @@
+ feature,importance
+ day_of_week,0.03474360040175998
+ OWNER_DESCR_encoded,0.035002645531480546
+ STATE_encoded,0.04257236407411825
+ is_weekend,0.042745448253639684
+ lat_squared,0.0492982200647884
+ month_cos,0.04941415345320393
+ dist_from_center,0.04974556334097185
+ years_since_1992,0.05059611732831027
+ lat_bin,0.05266350693152481
+ month_sin,0.05535090135037826
+ NWCG_REPORTING_AGENCY_encoded,0.06568755534754865
+ dow_sin,0.06723634255462703
+ doy_sin,0.06948528640895639
+ year_normalized,0.07218163313544536
+ month,0.07541548211445999
+ lat_lon_interaction,0.08870144324496855
+ doy_cos,0.09805861682638295
+ lon_bin,0.0997225819136466
+ LONGITUDE,0.10201343368889876
+ season,0.10830839963123856
+ FIRE_YEAR,0.10860962139407908
+ is_fire_season,0.12894166850207692
+ lon_squared,0.13777922755612584
+ DISCOVERY_DOY,0.14014555480776913
+ geo_cluster,0.1491778952047764
+ STAT_CAUSE_DESCR_encoded,0.15889478851167435
+ LATITUDE,0.17045380361507326
+ dow_cos,0.23151901964397537
reports/figures/geographic_distribution.png ADDED

Git LFS Details

  • SHA256: 422fd6f65aec85c0498d0df22d98e82882816247ab1cbf0619e2999079fa07b9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
reports/figures/missing_values.png ADDED
reports/figures/prediction_distribution.png ADDED
reports/figures/shap_importance.png ADDED
reports/figures/shap_importance_summary.png ADDED

Git LFS Details

  • SHA256: 76732c25c546c08ae6a3084ccba961eab87b6bc730805b051e02d371dd4786c3
  • Pointer size: 131 Bytes
  • Size of remote file: 347 kB
reports/figures/temporal_patterns.png ADDED

Git LFS Details

  • SHA256: 6d0ecf15472ae0e5995a509d0343ee9a0da2e70f664002c07e0542409c661bd6
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ # Core data processing
+ pandas>=2.0.0
+ numpy>=1.24.0
+ pyarrow>=14.0.0
+
+ # Machine learning
+ scikit-learn>=1.3.0  # also provides KMeans for coordinate clustering
+ lightgbm>=4.0.0
+ xgboost>=2.0.0
+ imbalanced-learn>=0.11.0
+
+ # Hyperparameter tuning
+ optuna>=3.4.0
+
+ # Model interpretability
+ shap>=0.43.0
+
+ # Visualization
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+
+ # Progress bars
+ tqdm>=4.66.0
+
+ # Joblib for model persistence
+ joblib>=1.3.0
run_pipeline.py ADDED
@@ -0,0 +1,91 @@
+ """
+ Main Pipeline Runner
+
+ Runs the complete ML pipeline from data extraction to evaluation.
+
+ Usage:
+     python run_pipeline.py            # Run full pipeline
+     python run_pipeline.py --skip-eda # Skip EDA step
+     python run_pipeline.py --tune     # Include hyperparameter tuning
+ """
+
+ import argparse
+ import subprocess
+ import sys
+ from pathlib import Path
+ from typing import Optional
+
+
+ def run_script(script_name: str, extra_args: Optional[list] = None) -> bool:
+     """Run a Python script and return success status."""
+     script_path = Path(__file__).parent / "scripts" / script_name
+
+     if not script_path.exists():
+         print(f"ERROR: Script not found: {script_path}")
+         return False
+
+     cmd = [sys.executable, str(script_path)]
+     if extra_args:
+         cmd.extend(extra_args)
+
+     print(f"\n{'='*60}")
+     print(f"Running: {script_name}")
+     print(f"{'='*60}\n")
+
+     result = subprocess.run(cmd, cwd=str(Path(__file__).parent))
+
+     if result.returncode != 0:
+         print(f"\nERROR: {script_name} failed with return code {result.returncode}")
+         return False
+
+     return True
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run wildfire ML pipeline")
+     parser.add_argument("--skip-eda", action="store_true", help="Skip EDA step")
+     parser.add_argument("--tune", action="store_true", help="Run hyperparameter tuning")
+     parser.add_argument("--from-step", type=int, default=1, help="Start from step number (1-6)")
+     args = parser.parse_args()
+
+     print("\n" + "="*60)
+     print("WILDFIRE SIZE CLASSIFICATION PIPELINE")
+     print("="*60)
+
+     steps = [
+         ("01_extract_data.py", []),
+         ("02_eda.py", []),
+         ("03_preprocess.py", []),
+         ("04_feature_engineering.py", []),
+         ("05_train_model.py", ["--tune"] if args.tune else []),
+         ("06_evaluate.py", []),
+     ]
+
+     for i, (script, extra_args) in enumerate(steps, 1):
+         if i < args.from_step:
+             print(f"\nSkipping step {i}: {script}")
+             continue
+
+         if args.skip_eda and "eda" in script:
+             print(f"\nSkipping EDA step: {script}")
+             continue
+
+         success = run_script(script, extra_args)
+         if not success:
+             print(f"\nPipeline failed at step {i}: {script}")
+             sys.exit(1)
+
+     print("\n" + "="*60)
+     print("✓ PIPELINE COMPLETE!")
+     print("="*60)
+     print("\nOutputs:")
+     print("  - Model:   models/wildfire_model.txt")
+     print("  - Figures: reports/figures/")
+     print("  - Data:    data/processed/")
+     print("\nNext steps:")
+     print("  - Review figures in reports/figures/")
+     print("  - Make predictions: python scripts/07_predict.py --lat 34.05 --lon -118.24")
+     print("="*60 + "\n")
+
+
+ if __name__ == "__main__":
+     main()
scripts/01_extract_data.py ADDED
@@ -0,0 +1,163 @@
+ """
+ Script 01: Extract Data from SQLite Database
+
+ This script connects to the FPA FOD SQLite database, extracts the Fires table,
+ and saves it as a Parquet file for faster subsequent processing.
+
+ Usage:
+     python scripts/01_extract_data.py
+ """
+
+ import sqlite3
+ import sys
+ from pathlib import Path
+
+ import pandas as pd
+
+ # Add project root to path
+ project_root = Path(__file__).parent.parent
+ sys.path.insert(0, str(project_root))
+
+ from config.config import (
+     SQLITE_DB_PATH,
+     PROCESSED_DATA_DIR,
+     RAW_PARQUET
+ )
+
+
+ def connect_to_database(db_path: Path) -> sqlite3.Connection:
+     """Connect to SQLite database."""
+     if not db_path.exists():
+         raise FileNotFoundError(
+             f"Database not found at {db_path}. "
+             "Please ensure the SQLite file is in the project root."
+         )
+
+     print(f"Connecting to database: {db_path}")
+     return sqlite3.connect(db_path)
+
+
+ def get_table_info(conn: sqlite3.Connection) -> None:
+     """Print information about tables in the database."""
+     cursor = conn.cursor()
+
+     # Get list of user tables (skip SpatiaLite system tables)
+     cursor.execute("""
+         SELECT name FROM sqlite_master
+         WHERE type='table' AND name NOT LIKE 'sqlite_%'
+         AND name NOT LIKE 'spatial%'
+         AND name NOT LIKE 'virt%'
+         AND name NOT LIKE 'view%'
+         AND name NOT LIKE 'geometry%'
+     """)
+     tables = cursor.fetchall()
+
+     print("\n" + "="*60)
+     print("DATABASE TABLES")
+     print("="*60)
+
+     for table in tables:
+         table_name = table[0]
+         try:
+             cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
+             count = cursor.fetchone()[0]
+             print(f"  {table_name}: {count:,} rows")
+         except Exception as e:
+             print(f"  {table_name}: (could not read - {str(e)[:30]})")
+
+     print("="*60 + "\n")
+
+
+ def extract_fires_table(conn: sqlite3.Connection) -> pd.DataFrame:
+     """Extract the Fires table from the database."""
+     print("Extracting Fires table...")
+
+     query = "SELECT * FROM Fires"
+     df = pd.read_sql_query(query, conn)
+
+     print(f"  Loaded {len(df):,} records")
+     print(f"  Columns: {len(df.columns)}")
+     print(f"  Memory usage: {df.memory_usage(deep=True).sum() / 1e6:.1f} MB")
+
+     return df
+
+
+ def save_to_parquet(df: pd.DataFrame, output_path: Path) -> None:
+     """Save DataFrame to Parquet format."""
+     # Create directory if it doesn't exist
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     print(f"\nSaving to Parquet: {output_path}")
+     df.to_parquet(output_path, index=False, compression='snappy')
+
+     file_size_mb = output_path.stat().st_size / 1e6
+     print(f"  File size: {file_size_mb:.1f} MB")
+
+
+ def print_data_summary(df: pd.DataFrame) -> None:
+     """Print summary statistics of the extracted data."""
+     print("\n" + "="*60)
+     print("DATA SUMMARY")
+     print("="*60)
+
+     print(f"\nDate Range: {df['FIRE_YEAR'].min()} - {df['FIRE_YEAR'].max()}")
+
+     print("\nFire Size Class Distribution:")
+     size_dist = df['FIRE_SIZE_CLASS'].value_counts().sort_index()
+     for cls, count in size_dist.items():
+         pct = count / len(df) * 100
+         print(f"  Class {cls}: {count:>10,} ({pct:>5.1f}%)")
+
+     print("\nTop 10 States by Fire Count:")
+     state_dist = df['STATE'].value_counts().head(10)
+     for state, count in state_dist.items():
+         print(f"  {state}: {count:,}")
+
+     print("\nTop Causes:")
+     cause_dist = df['STAT_CAUSE_DESCR'].value_counts().head(5)
+     for cause, count in cause_dist.items():
+         pct = count / len(df) * 100
+         print(f"  {cause}: {count:,} ({pct:.1f}%)")
+
+     print("\nMissing Values (top 10 columns):")
+     missing = df.isnull().sum().sort_values(ascending=False).head(10)
+     for col, count in missing.items():
+         if count > 0:
+             pct = count / len(df) * 100
+             print(f"  {col}: {count:,} ({pct:.1f}%)")
+
+     print("="*60 + "\n")
+
+
+ def main():
+     """Main extraction pipeline."""
+     print("\n" + "="*60)
+     print("WILDFIRE DATA EXTRACTION")
+     print("="*60 + "\n")
+
+     # Connect to database
+     conn = connect_to_database(SQLITE_DB_PATH)
+
+     try:
+         # Show database info
+         get_table_info(conn)
+
+         # Extract Fires table
+         df = extract_fires_table(conn)
+
+         # Print summary
+         print_data_summary(df)
+
+         # Save to Parquet
+         save_to_parquet(df, RAW_PARQUET)
+
+         print("\n✓ Data extraction complete!")
+         print(f"  Output: {RAW_PARQUET}")
+
+     finally:
+         conn.close()
+         print("  Database connection closed.")
+
+
+ if __name__ == "__main__":
+     main()
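`extract_fires_table` pulls the whole 1.88-million-row table into memory with one `read_sql_query` call. If memory is tight, pandas can stream the same query in chunks via the `chunksize` parameter; a sketch using a tiny in-memory stand-in for the Fires table (the schema here is illustrative only):

```python
import sqlite3

import pandas as pd

# Tiny in-memory stand-in for the real Fires table (illustration only)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Fires (FOD_ID INTEGER, FIRE_YEAR INTEGER)")
conn.executemany("INSERT INTO Fires VALUES (?, ?)",
                 [(i, 1992 + i % 24) for i in range(10)])

# chunksize makes read_sql_query return an iterator of DataFrames
chunks = pd.read_sql_query("SELECT * FROM Fires", conn, chunksize=4)
total_rows = sum(len(chunk) for chunk in chunks)
conn.close()
```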
scripts/02_eda.py ADDED
@@ -0,0 +1,345 @@
+ """
+ Script 02: Exploratory Data Analysis (EDA)
+
+ This script performs comprehensive EDA on the wildfire dataset:
+ - Class distribution analysis (original 7 classes and grouped 3 classes)
+ - Geographic distribution of fires
+ - Temporal patterns (yearly, monthly, seasonal)
+ - Missing value analysis
+ - Feature correlations
+
+ Generates visualization plots saved to reports/figures/
+
+ Usage:
+     python scripts/02_eda.py
+ """
+
+ import sys
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+
+ # Add project root to path
+ project_root = Path(__file__).parent.parent
+ sys.path.insert(0, str(project_root))
+
+ from config.config import (
+     RAW_PARQUET,
+     FIGURES_DIR,
+     FIRE_SIZE_CLASS_MAPPING,
+     TARGET_CLASS_NAMES
+ )
+
+ # Set style
+ plt.style.use('seaborn-v0_8-whitegrid')
+ sns.set_palette("husl")
+
+
+ def load_data() -> pd.DataFrame:
+     """Load the raw parquet data."""
+     print("Loading data...")
+     df = pd.read_parquet(RAW_PARQUET)
+     print(f"  Loaded {len(df):,} records")
+     return df
+
+
+ def analyze_class_distribution(df: pd.DataFrame) -> None:
+     """Analyze and visualize fire size class distribution."""
+     print("\n" + "="*60)
+     print("CLASS DISTRIBUTION ANALYSIS")
+     print("="*60)
+
+     # Original 7 classes
+     print("\nOriginal Fire Size Classes:")
+     original_dist = df['FIRE_SIZE_CLASS'].value_counts().sort_index()
+     for cls, count in original_dist.items():
+         pct = count / len(df) * 100
+         print(f"  Class {cls}: {count:>10,} ({pct:>6.2f}%)")
+
+     # Grouped 3 classes
+     df['fire_size_grouped'] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)
+
+     print("\nGrouped Classes (Target Variable):")
+     grouped_dist = df['fire_size_grouped'].value_counts().sort_index()
+     for cls_idx, count in grouped_dist.items():
+         pct = count / len(df) * 100
+         cls_name = TARGET_CLASS_NAMES[cls_idx]
+         print(f"  {cls_idx} ({cls_name:>6}): {count:>10,} ({pct:>6.2f}%)")
+
+     # Visualize
+     fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+     # Original distribution
+     colors_orig = sns.color_palette("YlOrRd", 7)
+     ax1 = axes[0]
+     original_dist.plot(kind='bar', ax=ax1, color=colors_orig, edgecolor='black')
+     ax1.set_title('Original Fire Size Class Distribution', fontsize=14, fontweight='bold')
+     ax1.set_xlabel('Fire Size Class')
+     ax1.set_ylabel('Count')
+     ax1.tick_params(axis='x', rotation=0)
+
+     # Add percentage labels
+     for i, (idx, val) in enumerate(original_dist.items()):
+         pct = val / len(df) * 100
+         ax1.annotate(f'{pct:.1f}%', (i, val), ha='center', va='bottom', fontsize=9)
+
+     # Grouped distribution
+     colors_grouped = ['#2ecc71', '#f39c12', '#e74c3c']  # Green, Orange, Red
+     ax2 = axes[1]
+     grouped_dist.plot(kind='bar', ax=ax2, color=colors_grouped, edgecolor='black')
+     ax2.set_title('Grouped Fire Size Distribution (Target)', fontsize=14, fontweight='bold')
+     ax2.set_xlabel('Fire Size Category')
+     ax2.set_ylabel('Count')
+     ax2.set_xticklabels(TARGET_CLASS_NAMES, rotation=0)
+
+     # Add percentage labels
+     for i, (idx, val) in enumerate(grouped_dist.items()):
+         pct = val / len(df) * 100
+         ax2.annotate(f'{pct:.1f}%', (i, val), ha='center', va='bottom', fontsize=10)
+
+     plt.tight_layout()
+     plt.savefig(FIGURES_DIR / 'class_distribution.png', dpi=150, bbox_inches='tight')
+     plt.close()
+     print("\n  Saved: class_distribution.png")
+
+
+ def analyze_geographic_distribution(df: pd.DataFrame) -> None:
+     """Analyze and visualize geographic distribution of fires."""
+     print("\n" + "="*60)
+     print("GEOGRAPHIC DISTRIBUTION")
+     print("="*60)
+
+     # Top states
+     print("\nTop 15 States by Fire Count:")
+     state_dist = df['STATE'].value_counts().head(15)
+     for state, count in state_dist.items():
+         pct = count / len(df) * 100
+         print(f"  {state}: {count:>10,} ({pct:>5.1f}%)")
+
+     # Fire locations scatter plot
+     fig, axes = plt.subplots(1, 2, figsize=(16, 6))
+
+     # All fires (sampled for performance)
+     sample_size = min(100000, len(df))
+     df_sample = df.sample(n=sample_size, random_state=42)
+
+     ax1 = axes[0]
+     scatter = ax1.scatter(
+         df_sample['LONGITUDE'],
+         df_sample['LATITUDE'],
+         c=df_sample['FIRE_SIZE_CLASS'].map({'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6}),
+         cmap='YlOrRd',
+         alpha=0.3,
+         s=1
+     )
+     ax1.set_title(f'Fire Locations (n={sample_size:,} sample)', fontsize=14, fontweight='bold')
+     ax1.set_xlabel('Longitude')
+     ax1.set_ylabel('Latitude')
+     ax1.set_xlim(-130, -65)
+     ax1.set_ylim(24, 50)
+     plt.colorbar(scatter, ax=ax1, label='Fire Size Class (A=0 to G=6)')
+
+     # Large fires only (E, F, G)
+     df_large = df[df['FIRE_SIZE_CLASS'].isin(['E', 'F', 'G'])]
+
+     ax2 = axes[1]
+     scatter2 = ax2.scatter(
+         df_large['LONGITUDE'],
+         df_large['LATITUDE'],
+         c=df_large['FIRE_SIZE_CLASS'].map({'E': 0, 'F': 1, 'G': 2}),
+         cmap='Reds',
+         alpha=0.5,
+         s=5
+     )
+     ax2.set_title(f'Large Fires Only (E/F/G, n={len(df_large):,})', fontsize=14, fontweight='bold')
+     ax2.set_xlabel('Longitude')
+     ax2.set_ylabel('Latitude')
+     ax2.set_xlim(-130, -65)
+     ax2.set_ylim(24, 50)
+
+     plt.tight_layout()
+     plt.savefig(FIGURES_DIR / 'geographic_distribution.png', dpi=150, bbox_inches='tight')
+     plt.close()
+     print("\n  Saved: geographic_distribution.png")
+
+
+ def analyze_temporal_patterns(df: pd.DataFrame) -> None:
+     """Analyze temporal patterns in the data."""
+     print("\n" + "="*60)
+     print("TEMPORAL PATTERNS")
+     print("="*60)
+
+     # Convert discovery day of year to month (DOY is stored numerically,
+     # so cast to str before parsing with '%j')
+     df['month'] = pd.to_datetime(df['DISCOVERY_DOY'].astype(int).astype(str), format='%j').dt.month
+
+     fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+
+     # Yearly trend
+     ax1 = axes[0, 0]
+     yearly = df.groupby('FIRE_YEAR').size()
+     yearly.plot(kind='line', ax=ax1, marker='o', linewidth=2, markersize=4)
+     ax1.set_title('Fires per Year', fontsize=12, fontweight='bold')
+     ax1.set_xlabel('Year')
+ ax1.set_ylabel('Number of Fires')
187
+ ax1.grid(True, alpha=0.3)
188
+
189
+ # Monthly distribution
190
+ ax2 = axes[0, 1]
191
+ monthly = df.groupby('month').size()
192
+ monthly.plot(kind='bar', ax=ax2, color='coral', edgecolor='black')
193
+ ax2.set_title('Fires by Month', fontsize=12, fontweight='bold')
194
+ ax2.set_xlabel('Month')
195
+ ax2.set_ylabel('Number of Fires')
196
+ ax2.tick_params(axis='x', rotation=0)
197
+
198
+ # Large fires by month
199
+ ax3 = axes[1, 0]
200
+ df['fire_size_grouped'] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)
201
+ monthly_by_class = df.groupby(['month', 'fire_size_grouped']).size().unstack(fill_value=0)
202
+ monthly_by_class.columns = TARGET_CLASS_NAMES
203
+ monthly_by_class.plot(kind='bar', ax=ax3, width=0.8,
204
+ color=['#2ecc71', '#f39c12', '#e74c3c'], edgecolor='black')
205
+ ax3.set_title('Fire Size Category by Month', fontsize=12, fontweight='bold')
206
+ ax3.set_xlabel('Month')
207
+ ax3.set_ylabel('Number of Fires')
208
+ ax3.tick_params(axis='x', rotation=0)
209
+ ax3.legend(title='Size Category')
210
+
211
+ # Fire causes
212
+ ax4 = axes[1, 1]
213
+ cause_dist = df['STAT_CAUSE_DESCR'].value_counts().head(10)
214
+ cause_dist.plot(kind='barh', ax=ax4, color='steelblue', edgecolor='black')
215
+ ax4.set_title('Top 10 Fire Causes', fontsize=12, fontweight='bold')
216
+ ax4.set_xlabel('Number of Fires')
217
+ ax4.invert_yaxis()
218
+
219
+ plt.tight_layout()
220
+ plt.savefig(FIGURES_DIR / 'temporal_patterns.png', dpi=150, bbox_inches='tight')
221
+ plt.close()
222
+ print(f"\n Saved: temporal_patterns.png")
223
+
224
+ # Print monthly stats
225
+ print("\nFires by Month:")
226
+ month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
227
+ 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
228
+ for month, count in monthly.items():
229
+ pct = count / len(df) * 100
230
+ print(f" {month_names[month-1]}: {count:>10,} ({pct:>5.1f}%)")
231
+
232
+
233
+ def analyze_missing_values(df: pd.DataFrame) -> None:
234
+ """Analyze missing values in the dataset."""
235
+ print("\n" + "="*60)
236
+ print("MISSING VALUE ANALYSIS")
237
+ print("="*60)
238
+
239
+ missing = df.isnull().sum()
240
+ missing_pct = (missing / len(df) * 100).round(2)
241
+
242
+ missing_df = pd.DataFrame({
243
+ 'Missing Count': missing,
244
+ 'Missing %': missing_pct
245
+ }).sort_values('Missing Count', ascending=False)
246
+
247
+ # Only show columns with missing values
248
+ missing_df = missing_df[missing_df['Missing Count'] > 0]
249
+
250
+ print(f"\nColumns with missing values: {len(missing_df)}")
251
+ print("\nTop 20 columns with missing values:")
252
+ for col, row in missing_df.head(20).iterrows():
253
+ print(f" {col}: {row['Missing Count']:,} ({row['Missing %']:.1f}%)")
254
+
255
+ # Visualize
256
+ if len(missing_df) > 0:
257
+ fig, ax = plt.subplots(figsize=(12, 8))
258
+ missing_df.head(20)['Missing %'].plot(
259
+ kind='barh', ax=ax, color='salmon', edgecolor='black'
260
+ )
261
+ ax.set_title('Missing Values by Column (Top 20)', fontsize=14, fontweight='bold')
262
+ ax.set_xlabel('Missing %')
263
+ ax.invert_yaxis()
264
+
265
+ plt.tight_layout()
266
+ plt.savefig(FIGURES_DIR / 'missing_values.png', dpi=150, bbox_inches='tight')
267
+ plt.close()
268
+ print(f"\n Saved: missing_values.png")
269
+
270
+
271
+ def analyze_cause_by_size(df: pd.DataFrame) -> None:
272
+ """Analyze fire causes by fire size category."""
273
+ print("\n" + "="*60)
274
+ print("FIRE CAUSE BY SIZE ANALYSIS")
275
+ print("="*60)
276
+
277
+ df['fire_size_grouped'] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)
278
+
279
+ # Cross-tabulation
280
+ cause_size = pd.crosstab(
281
+ df['STAT_CAUSE_DESCR'],
282
+ df['fire_size_grouped'],
283
+ normalize='index'
284
+ ) * 100
285
+ cause_size.columns = TARGET_CLASS_NAMES
286
+
287
+ print("\nFire Cause Distribution by Size Category (% of each cause):")
288
+ print(cause_size.round(1).to_string())
289
+
290
+ # Visualize
291
+ fig, ax = plt.subplots(figsize=(12, 8))
292
+ cause_size.plot(kind='barh', ax=ax, stacked=True,
293
+ color=['#2ecc71', '#f39c12', '#e74c3c'], edgecolor='white')
294
+ ax.set_title('Fire Size Distribution by Cause', fontsize=14, fontweight='bold')
295
+ ax.set_xlabel('Percentage')
296
+ ax.legend(title='Size Category', loc='lower right')
297
+ ax.invert_yaxis()
298
+
299
+ plt.tight_layout()
300
+ plt.savefig(FIGURES_DIR / 'cause_by_size.png', dpi=150, bbox_inches='tight')
301
+ plt.close()
302
+ print(f"\n Saved: cause_by_size.png")
303
+
304
+
305
+ def analyze_owner_distribution(df: pd.DataFrame) -> None:
306
+ """Analyze land owner distribution."""
307
+ print("\n" + "="*60)
308
+ print("LAND OWNER ANALYSIS")
309
+ print("="*60)
310
+
311
+ owner_dist = df['OWNER_DESCR'].value_counts()
312
+ print("\nFires by Land Owner:")
313
+ for owner, count in owner_dist.head(10).items():
314
+ pct = count / len(df) * 100
315
+ print(f" {owner}: {count:,} ({pct:.1f}%)")
316
+
317
+
318
+ def main():
319
+ """Main EDA pipeline."""
320
+ print("\n" + "="*60)
321
+ print("EXPLORATORY DATA ANALYSIS")
322
+ print("="*60)
323
+
324
+ # Create figures directory
325
+ FIGURES_DIR.mkdir(parents=True, exist_ok=True)
326
+
327
+ # Load data
328
+ df = load_data()
329
+
330
+ # Run analyses
331
+ analyze_class_distribution(df)
332
+ analyze_geographic_distribution(df)
333
+ analyze_temporal_patterns(df)
334
+ analyze_missing_values(df)
335
+ analyze_cause_by_size(df)
336
+ analyze_owner_distribution(df)
337
+
338
+ print("\n" + "="*60)
339
+ print("✓ EDA Complete!")
340
+ print(f" Figures saved to: {FIGURES_DIR}")
341
+ print("="*60 + "\n")
342
+
343
+
344
+ if __name__ == "__main__":
345
+ main()
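The cause-by-size analysis above hinges on `pd.crosstab(..., normalize='index')`, which turns raw counts into row-wise percentages so each cause's row sums to 100%. A minimal sketch on toy data (the column names here are illustrative, not the real dataset schema):

```python
import pandas as pd

# Toy records mirroring the cause-by-size cross-tabulation
df = pd.DataFrame({
    "cause": ["Lightning", "Lightning", "Debris", "Debris", "Debris"],
    "size":  ["Small", "Large", "Small", "Small", "Medium"],
})

# normalize='index' divides each cell by its row total,
# so every cause row sums to 100 after scaling
pct = pd.crosstab(df["cause"], df["size"], normalize="index") * 100
print(pct.round(1))
```

With `normalize='columns'` instead, each size class would sum to 100%, answering a different question (which causes dominate each size class).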
scripts/03_preprocess.py ADDED
@@ -0,0 +1,280 @@
+ """
+ Script 03: Data Preprocessing
+
+ This script preprocesses the raw wildfire data:
+ - Creates ordinal target variable (3 classes: Small, Medium, Large)
+ - Drops irrelevant columns (IDs, text fields, redundant info)
+ - Handles missing values
+ - Encodes categorical variables
+ - Splits data into train/test sets (stratified)
+
+ Usage:
+     python scripts/03_preprocess.py
+ """
+
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import LabelEncoder
+
+ # Add project root to path
+ project_root = Path(__file__).parent.parent
+ sys.path.insert(0, str(project_root))
+
+ from config.config import (
+     RAW_PARQUET,
+     PROCESSED_PARQUET,
+     TRAIN_PARQUET,
+     TEST_PARQUET,
+     PROCESSED_DATA_DIR,
+     FIRE_SIZE_CLASS_MAPPING,
+     TARGET_CLASS_NAMES,
+     TARGET_COLUMN,
+     COLUMNS_TO_DROP,
+     CATEGORICAL_FEATURES,
+     RANDOM_STATE,
+     TEST_SIZE
+ )
+
+
+ def load_data() -> pd.DataFrame:
+     """Load the raw parquet data."""
+     print("Loading raw data...")
+     df = pd.read_parquet(RAW_PARQUET)
+     print(f"  Loaded {len(df):,} records with {len(df.columns)} columns")
+     return df
+
+
+ def create_target_variable(df: pd.DataFrame) -> pd.DataFrame:
+     """Create ordinal target variable from FIRE_SIZE_CLASS."""
+     print("\nCreating ordinal target variable...")
+
+     # Map original classes to ordinal (0, 1, 2)
+     df[TARGET_COLUMN] = df['FIRE_SIZE_CLASS'].map(FIRE_SIZE_CLASS_MAPPING)
+
+     # Check for unmapped values
+     unmapped = df[TARGET_COLUMN].isna().sum()
+     if unmapped > 0:
+         print(f"  Warning: {unmapped} records could not be mapped. Dropping...")
+         df = df.dropna(subset=[TARGET_COLUMN])
+
+     df[TARGET_COLUMN] = df[TARGET_COLUMN].astype(int)
+
+     # Print distribution
+     print("\n  Target Variable Distribution:")
+     for val in sorted(df[TARGET_COLUMN].unique()):
+         count = (df[TARGET_COLUMN] == val).sum()
+         pct = count / len(df) * 100
+         print(f"    {val} ({TARGET_CLASS_NAMES[val]}): {count:,} ({pct:.2f}%)")
+
+     return df
+
+
+ def drop_irrelevant_columns(df: pd.DataFrame) -> pd.DataFrame:
+     """Drop columns not useful for prediction."""
+     print("\nDropping irrelevant columns...")
+
+     # Get columns that exist in the dataframe
+     cols_to_drop = [col for col in COLUMNS_TO_DROP if col in df.columns]
+
+     print(f"  Dropping {len(cols_to_drop)} columns:")
+     for col in cols_to_drop[:10]:
+         print(f"    - {col}")
+     if len(cols_to_drop) > 10:
+         print(f"    ... and {len(cols_to_drop) - 10} more")
+
+     df = df.drop(columns=cols_to_drop, errors='ignore')
+     print(f"  Remaining columns: {len(df.columns)}")
+
+     return df
+
+
+ def handle_missing_values(df: pd.DataFrame) -> pd.DataFrame:
+     """Handle missing values in the dataset."""
+     print("\nHandling missing values...")
+
+     initial_rows = len(df)
+
+     # Check missing in essential columns
+     essential_cols = ['LATITUDE', 'LONGITUDE', 'FIRE_YEAR', 'DISCOVERY_DOY', TARGET_COLUMN]
+     for col in essential_cols:
+         if col in df.columns:
+             missing = df[col].isna().sum()
+             if missing > 0:
+                 print(f"  {col}: {missing} missing values")
+
+     # Drop rows with missing essential values
+     df = df.dropna(subset=[c for c in essential_cols if c in df.columns])
+
+     # For categorical features, fill with 'Unknown'
+     for col in CATEGORICAL_FEATURES:
+         if col in df.columns:
+             missing = df[col].isna().sum()
+             if missing > 0:
+                 df[col] = df[col].fillna('Unknown')
+                 print(f"  {col}: Filled {missing} missing with 'Unknown'")
+
+     rows_dropped = initial_rows - len(df)
+     print(f"\n  Rows dropped due to missing essential values: {rows_dropped:,}")
+     print(f"  Remaining rows: {len(df):,}")
+
+     return df
+
+
+ def encode_categorical_features(df: pd.DataFrame) -> tuple[pd.DataFrame, dict]:
+     """Encode categorical features using Label Encoding."""
+     print("\nEncoding categorical features...")
+
+     encoders = {}
+
+     for col in CATEGORICAL_FEATURES:
+         if col in df.columns:
+             le = LabelEncoder()
+             df[f'{col}_encoded'] = le.fit_transform(df[col].astype(str))
+             encoders[col] = le
+
+             n_categories = len(le.classes_)
+             print(f"  {col}: {n_categories} categories")
+
+     return df, encoders
+
+
+ def select_features(df: pd.DataFrame) -> pd.DataFrame:
+     """Select features for modeling."""
+     print("\nSelecting features for modeling...")
+
+     # Features to keep
+     feature_cols = [
+         # Numerical
+         'LATITUDE', 'LONGITUDE', 'FIRE_YEAR', 'DISCOVERY_DOY',
+         # Encoded categorical
+         'NWCG_REPORTING_AGENCY_encoded',
+         'STAT_CAUSE_DESCR_encoded',
+         'STATE_encoded',
+         'OWNER_DESCR_encoded',
+         # Target
+         TARGET_COLUMN
+     ]
+
+     # Keep only columns that exist
+     available_cols = [col for col in feature_cols if col in df.columns]
+
+     # Also keep original categorical columns for reference
+     original_cats = [col for col in CATEGORICAL_FEATURES if col in df.columns]
+
+     all_cols = available_cols + original_cats
+     all_cols = list(dict.fromkeys(all_cols))  # Remove duplicates, preserve order
+
+     df = df[all_cols]
+
+     print(f"  Selected {len(available_cols)} feature columns + target")
+     print(f"  Final columns: {list(df.columns)}")
+
+     return df
+
+
+ def split_data(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
+     """Split data into train and test sets."""
+     print("\nSplitting data into train/test sets...")
+
+     train_df, test_df = train_test_split(
+         df,
+         test_size=TEST_SIZE,
+         random_state=RANDOM_STATE,
+         stratify=df[TARGET_COLUMN]
+     )
+
+     print(f"  Train set: {len(train_df):,} rows ({100*(1-TEST_SIZE):.0f}%)")
+     print(f"  Test set:  {len(test_df):,} rows ({100*TEST_SIZE:.0f}%)")
+
+     # Verify stratification
+     print("\n  Target distribution in splits:")
+     for name, data in [('Train', train_df), ('Test', test_df)]:
+         dist = data[TARGET_COLUMN].value_counts(normalize=True).sort_index() * 100
+         dist_str = ", ".join([f"{TARGET_CLASS_NAMES[i]}: {v:.1f}%" for i, v in dist.items()])
+         print(f"    {name}: {dist_str}")
+
+     return train_df, test_df
+
+
+ def save_data(df: pd.DataFrame, train_df: pd.DataFrame, test_df: pd.DataFrame) -> None:
+     """Save processed data to parquet files."""
+     print("\nSaving processed data...")
+
+     # Create directory if needed
+     PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+     # Save full processed data
+     df.to_parquet(PROCESSED_PARQUET, index=False)
+     print(f"  Full processed data: {PROCESSED_PARQUET}")
+
+     # Save train/test splits
+     train_df.to_parquet(TRAIN_PARQUET, index=False)
+     print(f"  Train data: {TRAIN_PARQUET}")
+
+     test_df.to_parquet(TEST_PARQUET, index=False)
+     print(f"  Test data: {TEST_PARQUET}")
+
+
+ def print_summary(df: pd.DataFrame) -> None:
+     """Print preprocessing summary."""
+     print("\n" + "="*60)
+     print("PREPROCESSING SUMMARY")
+     print("="*60)
+
+     print(f"\nDataset shape: {df.shape}")
+     print("\nColumn types:")
+     print(df.dtypes.value_counts().to_string())
+
+     print("\nFeature statistics:")
+     numerical_cols = df.select_dtypes(include=[np.number]).columns
+     for col in numerical_cols:
+         if col != TARGET_COLUMN:
+             print(f"  {col}:")
+             print(f"    Range: [{df[col].min():.2f}, {df[col].max():.2f}]")
+             print(f"    Mean: {df[col].mean():.2f}, Std: {df[col].std():.2f}")
+
+
+ def main():
+     """Main preprocessing pipeline."""
+     print("\n" + "="*60)
+     print("DATA PREPROCESSING")
+     print("="*60)
+
+     # Load data
+     df = load_data()
+
+     # Create target variable
+     df = create_target_variable(df)
+
+     # Drop irrelevant columns
+     df = drop_irrelevant_columns(df)
+
+     # Handle missing values
+     df = handle_missing_values(df)
+
+     # Encode categorical features
+     df, encoders = encode_categorical_features(df)
+
+     # Select features
+     df = select_features(df)
+
+     # Split data
+     train_df, test_df = split_data(df)
+
+     # Save data
+     save_data(df, train_df, test_df)
+
+     # Print summary
+     print_summary(df)
+
+     print("\n" + "="*60)
+     print("✓ Preprocessing Complete!")
+     print("="*60 + "\n")
+
+
+ if __name__ == "__main__":
+     main()
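The two load-bearing steps in this preprocessing script are label-encoding the categoricals and the stratified train/test split, which keeps the (heavily imbalanced) class proportions identical in both halves. A minimal sketch on a toy frame (the column names are stand-ins, not the real schema):

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Toy frame: 100 rows, target distribution 50% / 25% / 25%
df = pd.DataFrame({
    "cause": ["Arson", "Lightning", "Arson", "Campfire"] * 25,
    "target": [0, 0, 1, 2] * 25,
})

# Label-encode the categorical column, as in encode_categorical_features()
le = LabelEncoder()
df["cause_encoded"] = le.fit_transform(df["cause"].astype(str))

# stratify= preserves the target distribution in both splits
train_df, test_df = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df["target"]
)
print(train_df["target"].value_counts(normalize=True).sort_index())
```

Without `stratify=`, a rare class can end up over- or under-represented in the test set purely by chance, which would distort every downstream metric.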
scripts/04_feature_engineering.py ADDED
@@ -0,0 +1,290 @@
+ """
+ Script 04: Feature Engineering
+
+ This script creates additional features for the model:
+ - Temporal features (month, season, day of week)
+ - Geospatial features (lat/lon bins, clustering, interactions)
+ - Coordinate transformations
+
+ Usage:
+     python scripts/04_feature_engineering.py
+ """
+
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.cluster import KMeans
+ from sklearn.preprocessing import StandardScaler
+
+ # Add project root to path
+ project_root = Path(__file__).parent.parent
+ sys.path.insert(0, str(project_root))
+
+ from config.config import (
+     TRAIN_PARQUET,
+     TEST_PARQUET,
+     FEATURES_PARQUET,
+     PROCESSED_DATA_DIR,
+     TARGET_COLUMN,
+     N_GEO_CLUSTERS,
+     LAT_BINS,
+     LON_BINS,
+     RANDOM_STATE
+ )
+
+
+ def load_data() -> tuple[pd.DataFrame, pd.DataFrame]:
+     """Load train and test data."""
+     print("Loading data...")
+     train_df = pd.read_parquet(TRAIN_PARQUET)
+     test_df = pd.read_parquet(TEST_PARQUET)
+     print(f"  Train: {len(train_df):,} rows")
+     print(f"  Test:  {len(test_df):,} rows")
+     return train_df, test_df
+
+
+ def create_temporal_features(df: pd.DataFrame) -> pd.DataFrame:
+     """Create temporal features from DISCOVERY_DOY."""
+     print("\nCreating temporal features...")
+
+     # Convert day of year to datetime for feature extraction
+     # Using a non-leap year as reference
+     reference_year = 2001
+     df['temp_date'] = pd.to_datetime(
+         df['DISCOVERY_DOY'].astype(int).astype(str) + f'-{reference_year}',
+         format='%j-%Y',
+         errors='coerce'
+     )
+
+     # Handle invalid dates
+     invalid_dates = df['temp_date'].isna().sum()
+     if invalid_dates > 0:
+         print(f"  Warning: {invalid_dates} invalid day of year values")
+         # Fill with median day
+         median_doy = df['DISCOVERY_DOY'].median()
+         df.loc[df['temp_date'].isna(), 'temp_date'] = pd.to_datetime(
+             f'{int(median_doy)}-{reference_year}', format='%j-%Y'
+         )
+
+     # Extract features
+     df['month'] = df['temp_date'].dt.month
+     df['day_of_week'] = df['temp_date'].dt.dayofweek  # 0=Monday, 6=Sunday
+     df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
+
+     # Season (1=Winter, 2=Spring, 3=Summer, 4=Fall)
+     df['season'] = df['month'].apply(lambda m:
+         1 if m in [12, 1, 2] else
+         2 if m in [3, 4, 5] else
+         3 if m in [6, 7, 8] else 4
+     )
+
+     # Fire season indicator (peak fire months: June-October)
+     df['is_fire_season'] = df['month'].isin([6, 7, 8, 9, 10]).astype(int)
+
+     # Drop temporary date column
+     df = df.drop(columns=['temp_date'])
+
+     print("  Created: month, day_of_week, is_weekend, season, is_fire_season")
+
+     return df
+
+
+ def create_geospatial_features(train_df: pd.DataFrame, test_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, KMeans]:
+     """Create geospatial features from coordinates."""
+     print("\nCreating geospatial features...")
+
+     # 1. Latitude/Longitude bins
+     print("  Creating coordinate bins...")
+
+     # Define bin edges based on continental US bounds
+     lat_min, lat_max = 24.0, 50.0
+     lon_min, lon_max = -125.0, -66.0
+
+     lat_edges = np.linspace(lat_min, lat_max, LAT_BINS + 1)
+     lon_edges = np.linspace(lon_min, lon_max, LON_BINS + 1)
+
+     for df in [train_df, test_df]:
+         df['lat_bin'] = pd.cut(df['LATITUDE'], bins=lat_edges, labels=False, include_lowest=True)
+         df['lon_bin'] = pd.cut(df['LONGITUDE'], bins=lon_edges, labels=False, include_lowest=True)
+
+         # Fill NaN bins (locations outside continental US) with nearest bin
+         df['lat_bin'] = df['lat_bin'].fillna(df['lat_bin'].median()).astype(int)
+         df['lon_bin'] = df['lon_bin'].fillna(df['lon_bin'].median()).astype(int)
+
+     # 2. Geographic clustering using K-Means
+     print(f"  Fitting K-Means clustering (k={N_GEO_CLUSTERS})...")
+
+     # Prepare coordinates for clustering
+     train_coords = train_df[['LATITUDE', 'LONGITUDE']].values
+     test_coords = test_df[['LATITUDE', 'LONGITUDE']].values
+
+     # Scale coordinates
+     scaler = StandardScaler()
+     train_coords_scaled = scaler.fit_transform(train_coords)
+     test_coords_scaled = scaler.transform(test_coords)
+
+     # Fit K-Means on train data
+     kmeans = KMeans(n_clusters=N_GEO_CLUSTERS, random_state=RANDOM_STATE, n_init=10)
+     train_df['geo_cluster'] = kmeans.fit_predict(train_coords_scaled)
+     test_df['geo_cluster'] = kmeans.predict(test_coords_scaled)
+
+     print("  Cluster distribution (train):")
+     cluster_dist = train_df['geo_cluster'].value_counts().sort_index()
+     for cluster, count in cluster_dist.items():
+         pct = count / len(train_df) * 100
+         if pct >= 3:  # Only show clusters with >= 3%
+             print(f"    Cluster {cluster}: {count:,} ({pct:.1f}%)")
+
+     # 3. Coordinate interactions
+     print("  Creating coordinate interactions...")
+
+     for df in [train_df, test_df]:
+         # Quadratic terms (captures non-linear patterns)
+         df['lat_squared'] = df['LATITUDE'] ** 2
+         df['lon_squared'] = df['LONGITUDE'] ** 2
+         df['lat_lon_interaction'] = df['LATITUDE'] * df['LONGITUDE']
+
+         # Distance from geographic center of continental US
+         # Approximate center: 39.8°N, 98.6°W
+         center_lat, center_lon = 39.8, -98.6
+         df['dist_from_center'] = np.sqrt(
+             (df['LATITUDE'] - center_lat) ** 2 +
+             (df['LONGITUDE'] - center_lon) ** 2
+         )
+
+     print("  Created: lat_bin, lon_bin, geo_cluster, lat_squared, lon_squared, lat_lon_interaction, dist_from_center")
+
+     return train_df, test_df, kmeans
+
+
+ def create_cyclical_features(df: pd.DataFrame) -> pd.DataFrame:
+     """Create cyclical encoding for periodic features."""
+     print("\nCreating cyclical features...")
+
+     # Cyclical encoding for month (captures January-December continuity)
+     df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
+     df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
+
+     # Cyclical encoding for day of year
+     df['doy_sin'] = np.sin(2 * np.pi * df['DISCOVERY_DOY'] / 365)
+     df['doy_cos'] = np.cos(2 * np.pi * df['DISCOVERY_DOY'] / 365)
+
+     # Cyclical encoding for day of week
+     df['dow_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
+     df['dow_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
+
+     print("  Created: month_sin/cos, doy_sin/cos, dow_sin/cos")
+
+     return df
+
+
+ def create_year_features(df: pd.DataFrame) -> pd.DataFrame:
+     """Create year-based features."""
+     print("\nCreating year features...")
+
+     # Normalized year (0-1 scale for 1992-2015)
+     min_year, max_year = 1992, 2015
+     df['year_normalized'] = (df['FIRE_YEAR'] - min_year) / (max_year - min_year)
+
+     # Years since start
+     df['years_since_1992'] = df['FIRE_YEAR'] - min_year
+
+     print("  Created: year_normalized, years_since_1992")
+
+     return df
+
+
+ def get_feature_columns(df: pd.DataFrame) -> list:
+     """Get list of feature columns for modeling."""
+     # Exclude target, original categorical text columns, and intermediate columns
+     exclude_cols = [
+         TARGET_COLUMN,
+         'NWCG_REPORTING_AGENCY', 'STAT_CAUSE_DESCR', 'STATE', 'OWNER_DESCR',
+         'COUNTY'  # If present
+     ]
+
+     feature_cols = [col for col in df.columns if col not in exclude_cols]
+     return feature_cols
+
+
+ def save_data(train_df: pd.DataFrame, test_df: pd.DataFrame) -> None:
+     """Save feature-engineered data."""
+     print("\nSaving feature-engineered data...")
+
+     # Overwrite train/test files with new features
+     train_df.to_parquet(TRAIN_PARQUET, index=False)
+     test_df.to_parquet(TEST_PARQUET, index=False)
+
+     print(f"  Train data: {TRAIN_PARQUET}")
+     print(f"  Test data: {TEST_PARQUET}")
+
+     # Also save combined for reference
+     combined = pd.concat([train_df, test_df], ignore_index=True)
+     combined.to_parquet(FEATURES_PARQUET, index=False)
+     print(f"  Combined data: {FEATURES_PARQUET}")
+
+
+ def print_summary(train_df: pd.DataFrame) -> None:
+     """Print feature engineering summary."""
+     print("\n" + "="*60)
+     print("FEATURE ENGINEERING SUMMARY")
+     print("="*60)
+
+     feature_cols = get_feature_columns(train_df)
+
+     print(f"\nTotal features: {len(feature_cols)}")
+     print("\nFeature list:")
+
+     # Group features by type
+     temporal = [c for c in feature_cols if c in ['month', 'day_of_week', 'is_weekend', 'season', 'is_fire_season',
+                                                  'month_sin', 'month_cos', 'doy_sin', 'doy_cos', 'dow_sin', 'dow_cos']]
+     geospatial = [c for c in feature_cols if c in ['lat_bin', 'lon_bin', 'geo_cluster', 'lat_squared', 'lon_squared',
+                                                    'lat_lon_interaction', 'dist_from_center', 'LATITUDE', 'LONGITUDE']]
+     year_feats = [c for c in feature_cols if c in ['FIRE_YEAR', 'year_normalized', 'years_since_1992', 'DISCOVERY_DOY']]
+     encoded = [c for c in feature_cols if c.endswith('_encoded')]
+
+     print(f"\n  Temporal ({len(temporal)}): {temporal}")
+     print(f"\n  Geospatial ({len(geospatial)}): {geospatial}")
+     print(f"\n  Year-based ({len(year_feats)}): {year_feats}")
+     print(f"\n  Encoded categorical ({len(encoded)}): {encoded}")
+
+
+ def main():
+     """Main feature engineering pipeline."""
+     print("\n" + "="*60)
+     print("FEATURE ENGINEERING")
+     print("="*60)
+
+     # Load data
+     train_df, test_df = load_data()
+
+     # Create temporal features
+     train_df = create_temporal_features(train_df)
+     test_df = create_temporal_features(test_df)
+
+     # Create geospatial features
+     train_df, test_df, kmeans = create_geospatial_features(train_df, test_df)
+
+     # Create cyclical features
+     train_df = create_cyclical_features(train_df)
+     test_df = create_cyclical_features(test_df)
+
+     # Create year features
+     train_df = create_year_features(train_df)
+     test_df = create_year_features(test_df)
+
+     # Save data
+     save_data(train_df, test_df)
+
+     # Print summary
+     print_summary(train_df)
+
+     print("\n" + "="*60)
+     print("✓ Feature Engineering Complete!")
+     print("="*60 + "\n")
+
+
+ if __name__ == "__main__":
+     main()
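The sin/cos trick used in `create_cyclical_features` projects a periodic feature onto the unit circle, so the model sees December and January as neighbors rather than 11 units apart. A small sketch of the same encoding (the helper function name is ours, not part of the scripts):

```python
import numpy as np

def month_to_cyclical(month: int) -> tuple[float, float]:
    """Project a month (1-12) onto the unit circle."""
    angle = 2 * np.pi * month / 12
    return np.sin(angle), np.cos(angle)

# December (12) and January (1) are adjacent on the circle,
# while December and June sit on opposite sides of it.
dec = np.array(month_to_cyclical(12))
jan = np.array(month_to_cyclical(1))
jun = np.array(month_to_cyclical(6))
print(np.linalg.norm(dec - jan))  # small: adjacent months
print(np.linalg.norm(dec - jun))  # large: opposite months
```

The same reasoning motivates the `doy_sin`/`doy_cos` and `dow_sin`/`dow_cos` pairs: day 365 borders day 1, and Sunday borders Monday.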
scripts/05_train_model.py ADDED
@@ -0,0 +1,370 @@
+ """
+ Script 05: Model Training
+
+ This script trains the ordinal classification model:
+ - Uses LightGBM for multi-class ordinal classification
+ - Implements class weighting for imbalanced data
+ - Performs cross-validation
+ - Includes hyperparameter tuning with Optuna
+ - Saves the trained model
+
+ Ordinal Classification Approach:
+     Since fire size classes have a natural order (Small < Medium < Large),
+     we use ordinal-aware training with cumulative link model concepts.
+
+ Usage:
+     python scripts/05_train_model.py [--tune]
+ """
+
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+
+ import joblib
+ import lightgbm as lgb
+ import numpy as np
+ import optuna
+ import pandas as pd
+ from sklearn.metrics import (
+     accuracy_score,
+     balanced_accuracy_score,
+     classification_report,
+     cohen_kappa_score,
+     f1_score,
+ )
+ from sklearn.model_selection import StratifiedKFold
+ from sklearn.utils.class_weight import compute_class_weight
+
+ # Add project root to path
+ project_root = Path(__file__).parent.parent
+ sys.path.insert(0, str(project_root))
+
+ from config.config import (
+     TRAIN_PARQUET,
+     TEST_PARQUET,
+     MODELS_DIR,
+     TARGET_COLUMN,
+     TARGET_CLASS_NAMES,
+     LIGHTGBM_PARAMS,
+     OPTUNA_SEARCH_SPACE,
+     N_OPTUNA_TRIALS,
+     N_FOLDS,
+     RANDOM_STATE,
+     USE_CLASS_WEIGHTS,
+     PRIMARY_METRIC
+ )
+
+
+ def load_data() -> tuple[pd.DataFrame, pd.DataFrame]:
+     """Load train and test data."""
+     print("Loading data...")
+     train_df = pd.read_parquet(TRAIN_PARQUET)
+     test_df = pd.read_parquet(TEST_PARQUET)
+     print(f"  Train: {len(train_df):,} rows")
+     print(f"  Test:  {len(test_df):,} rows")
+     return train_df, test_df
+
+
+ def get_feature_columns(df: pd.DataFrame) -> list:
+     """Get list of feature columns for modeling."""
+     exclude_cols = [
+         TARGET_COLUMN,
+         'NWCG_REPORTING_AGENCY', 'STAT_CAUSE_DESCR', 'STATE', 'OWNER_DESCR',
+         'COUNTY'
+     ]
+     feature_cols = [col for col in df.columns if col not in exclude_cols]
+     return feature_cols
+
+
+ def prepare_data(train_df: pd.DataFrame, test_df: pd.DataFrame) -> tuple:
+     """Prepare features and targets for training."""
+     print("\nPreparing data...")
+
+     feature_cols = get_feature_columns(train_df)
+
+     X_train = train_df[feature_cols].values
+     y_train = train_df[TARGET_COLUMN].values
+     X_test = test_df[feature_cols].values
+     y_test = test_df[TARGET_COLUMN].values
+
+     print(f"  Features: {len(feature_cols)}")
+     print(f"  Feature columns: {feature_cols}")
+
+     return X_train, y_train, X_test, y_test, feature_cols
+
+
+ def compute_weights(y_train: np.ndarray) -> np.ndarray:
+     """Compute sample weights for class imbalance."""
+     print("\nComputing class weights...")
+
+     classes = np.unique(y_train)
+     class_weights = compute_class_weight(
+         class_weight='balanced',
+         classes=classes,
+         y=y_train
+     )
+
+     weight_dict = dict(zip(classes, class_weights))
+     print(f"  Class weights: {weight_dict}")
+
+     # Create sample weights array
+     sample_weights = np.array([weight_dict[y] for y in y_train])
+
+     return sample_weights
+
+
+ def evaluate_model(y_true: np.ndarray, y_pred: np.ndarray, prefix: str = "") -> dict:
+     """Evaluate model predictions."""
+     metrics = {
+         'accuracy': accuracy_score(y_true, y_pred),
+         'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
+         'macro_f1': f1_score(y_true, y_pred, average='macro'),
+         'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
+         'cohen_kappa': cohen_kappa_score(y_true, y_pred, weights='linear')  # Linear weights for ordinal
+     }
+
+     if prefix:
+         print(f"\n{prefix} Metrics:")
+     else:
+         print("\nMetrics:")
+
+     for name, value in metrics.items():
+         print(f"  {name}: {value:.4f}")
+
+     return metrics
+
+
+ def cross_validate(X: np.ndarray, y: np.ndarray, params: dict,
+                    sample_weights: np.ndarray = None) -> tuple[float, float]:
+     """Perform cross-validation and return mean and std of primary metric."""
+
+     skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)
+     scores = []
+
+     for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
+         X_fold_train, X_fold_val = X[train_idx], X[val_idx]
+         y_fold_train, y_fold_val = y[train_idx], y[val_idx]
+
+         if sample_weights is not None:
+             weights_fold = sample_weights[train_idx]
+         else:
+             weights_fold = None
+
+         # Create LightGBM datasets
+         train_data = lgb.Dataset(X_fold_train, label=y_fold_train, weight=weights_fold)
+         val_data = lgb.Dataset(X_fold_val, label=y_fold_val, reference=train_data)
157
+
158
+ # Train model
159
+ model = lgb.train(
160
+ params,
161
+ train_data,
162
+ num_boost_round=params.get('n_estimators', 500),
163
+ valid_sets=[val_data],
164
+ callbacks=[lgb.early_stopping(50, verbose=False)]
165
+ )
166
+
167
+ # Predict
168
+ y_pred = model.predict(X_fold_val)
169
+ y_pred_class = np.argmax(y_pred, axis=1)
170
+
171
+ # Score
172
+ score = f1_score(y_fold_val, y_pred_class, average='macro')
173
+ scores.append(score)
174
+
175
+ return np.mean(scores), np.std(scores)
176
+
177
+
178
+ def objective(trial: optuna.Trial, X: np.ndarray, y: np.ndarray,
179
+ sample_weights: np.ndarray) -> float:
180
+ """Optuna objective function for hyperparameter tuning."""
181
+
182
+ params = LIGHTGBM_PARAMS.copy()
183
+
184
+ # Sample hyperparameters
185
+ params['n_estimators'] = trial.suggest_int('n_estimators', *OPTUNA_SEARCH_SPACE['n_estimators'])
186
+ params['max_depth'] = trial.suggest_int('max_depth', *OPTUNA_SEARCH_SPACE['max_depth'])
187
+ params['learning_rate'] = trial.suggest_float('learning_rate', *OPTUNA_SEARCH_SPACE['learning_rate'], log=True)
188
+ params['num_leaves'] = trial.suggest_int('num_leaves', *OPTUNA_SEARCH_SPACE['num_leaves'])
189
+ params['min_child_samples'] = trial.suggest_int('min_child_samples', *OPTUNA_SEARCH_SPACE['min_child_samples'])
190
+ params['subsample'] = trial.suggest_float('subsample', *OPTUNA_SEARCH_SPACE['subsample'])
191
+ params['colsample_bytree'] = trial.suggest_float('colsample_bytree', *OPTUNA_SEARCH_SPACE['colsample_bytree'])
192
+ params['reg_alpha'] = trial.suggest_float('reg_alpha', *OPTUNA_SEARCH_SPACE['reg_alpha'])
193
+ params['reg_lambda'] = trial.suggest_float('reg_lambda', *OPTUNA_SEARCH_SPACE['reg_lambda'])
194
+
195
+ # Cross-validate
196
+ mean_score, _ = cross_validate(X, y, params, sample_weights)
197
+
198
+ return mean_score
199
+
200
+
201
+ def tune_hyperparameters(X: np.ndarray, y: np.ndarray,
202
+ sample_weights: np.ndarray) -> dict:
203
+ """Tune hyperparameters using Optuna."""
204
+ print("\n" + "="*60)
205
+ print("HYPERPARAMETER TUNING")
206
+ print("="*60)
207
+
208
+ print(f"\nRunning {N_OPTUNA_TRIALS} Optuna trials...")
209
+
210
+ # Create study
211
+ study = optuna.create_study(
212
+ direction='maximize',
213
+ sampler=optuna.samplers.TPESampler(seed=RANDOM_STATE)
214
+ )
215
+
216
+ # Optimize
217
+ study.optimize(
218
+ lambda trial: objective(trial, X, y, sample_weights),
219
+ n_trials=N_OPTUNA_TRIALS,
220
+ show_progress_bar=True
221
+ )
222
+
223
+ print("\nBest trial:")
224
+ print(f" Value (macro F1): {study.best_trial.value:.4f}")
225
+ print(f" Params: {study.best_trial.params}")
226
+
227
+ # Merge best params with base params
228
+ best_params = LIGHTGBM_PARAMS.copy()
229
+ best_params.update(study.best_trial.params)
230
+
231
+ best_params_path = MODELS_DIR / 'best_params.json'
232
+ with open(best_params_path, 'w') as f:
233
+ json.dump(study.best_trial.params, f)
234
+ print(f" Best params saved: {best_params_path}")
235
+
236
+ return best_params
237
+
238
+
239
+ def train_final_model(X_train: np.ndarray, y_train: np.ndarray,
240
+ X_test: np.ndarray, y_test: np.ndarray,
241
+ params: dict, sample_weights: np.ndarray,
242
+ feature_names: list) -> tuple[lgb.Booster, dict]:
243
+ """Train final model on full training data."""
244
+ print("\n" + "="*60)
245
+ print("TRAINING FINAL MODEL")
246
+ print("="*60)
247
+
248
+ # Create datasets
249
+ train_data = lgb.Dataset(X_train, label=y_train, weight=sample_weights,
250
+ feature_name=feature_names)
251
+ val_data = lgb.Dataset(X_test, label=y_test, reference=train_data,
252
+ feature_name=feature_names)
253
+
254
+ # Train
255
+ print("\nTraining...")
256
+ model = lgb.train(
257
+ params,
258
+ train_data,
259
+ num_boost_round=params.get('n_estimators', 2000),
260
+ valid_sets=[train_data, val_data],
261
+ valid_names=['train', 'test'],
262
+ callbacks=[
263
+ lgb.early_stopping(50, verbose=True),
264
+ lgb.log_evaluation(period=50)
265
+ ]
266
+ )
267
+
268
+ # Evaluate
269
+ print("\n" + "-"*40)
270
+
271
+ # Train predictions
272
+ y_train_pred = np.argmax(model.predict(X_train), axis=1)
273
+ evaluate_model(y_train, y_train_pred, "Train")
274
+
275
+ # Test predictions
276
+ y_test_pred = np.argmax(model.predict(X_test), axis=1)
277
+ test_metrics = evaluate_model(y_test, y_test_pred, "Test")
278
+
279
+ # Classification report
280
+ print("\nClassification Report (Test):")
281
+ print(classification_report(y_test, y_test_pred, target_names=TARGET_CLASS_NAMES))
282
+
283
+ return model, test_metrics
284
+
285
+
286
+ def save_model(model: lgb.Booster, params: dict, feature_names: list, metrics: dict) -> None:
287
+ """Save trained model and metadata."""
288
+ print("\nSaving model...")
289
+
290
+ MODELS_DIR.mkdir(parents=True, exist_ok=True)
291
+
292
+ # Save LightGBM model
293
+ model_path = MODELS_DIR / 'wildfire_model.txt'
294
+ model.save_model(str(model_path))
295
+ print(f" Model: {model_path}")
296
+
297
+ # Save metadata
298
+ metadata = {
299
+ 'params': params,
300
+ 'feature_names': feature_names,
301
+ 'metrics': metrics,
302
+ 'target_classes': TARGET_CLASS_NAMES
303
+ }
304
+ metadata_path = MODELS_DIR / 'model_metadata.joblib'
305
+ joblib.dump(metadata, metadata_path)
306
+ print(f" Metadata: {metadata_path}")
307
+
308
+
309
+ def main():
310
+ """Main training pipeline."""
311
+ # Parse arguments
312
+ parser = argparse.ArgumentParser(description='Train wildfire classification model')
313
+ parser.add_argument('--tune', action='store_true', help='Run hyperparameter tuning')
314
+ args = parser.parse_args()
315
+
316
+ print("\n" + "="*60)
317
+ print("MODEL TRAINING")
318
+ print("="*60)
319
+
320
+ # Load data
321
+ train_df, test_df = load_data()
322
+
323
+ # Prepare data
324
+ X_train, y_train, X_test, y_test, feature_cols = prepare_data(train_df, test_df)
325
+
326
+ # Compute class weights
327
+ sample_weights = None
328
+ if USE_CLASS_WEIGHTS:
329
+ sample_weights = compute_weights(y_train)
330
+
331
+ # Get parameters
332
+ if args.tune:
333
+ params = tune_hyperparameters(X_train, y_train, sample_weights)
334
+ else:
335
+ best_params_path = MODELS_DIR / 'best_params.json'
336
+ if best_params_path.exists():
337
+ # Load saved best params
338
+ with open(best_params_path, 'r') as f:
339
+ tuned_params = json.load(f)
340
+ params = LIGHTGBM_PARAMS.copy()
341
+ params.update(tuned_params)
342
+ print(f"Loaded best params from {best_params_path}")
343
+ else:
344
+ # Fallback to defaults
345
+ params = LIGHTGBM_PARAMS.copy()
346
+ params['n_estimators'] = 500
347
+ params['max_depth'] = 8
348
+ params['learning_rate'] = 0.05
349
+ params['num_leaves'] = 64
350
+ params['min_child_samples'] = 50
351
+ params['subsample'] = 0.8
352
+ params['colsample_bytree'] = 0.8
353
+ print("No saved params found; using defaults")
354
+
355
+ # Train final model
356
+ model, metrics = train_final_model(
357
+ X_train, y_train, X_test, y_test,
358
+ params, sample_weights, feature_cols
359
+ )
360
+
361
+ # Save model
362
+ save_model(model, params, feature_cols, metrics)
363
+
364
+ print("\n" + "="*60)
365
+ print("✓ Training Complete!")
366
+ print("="*60 + "\n")
367
+
368
+
369
+ if __name__ == "__main__":
370
+ main()
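The sample weights produced in `compute_weights` come from scikit-learn's `compute_class_weight(class_weight='balanced', ...)`, which assigns each class the weight `n_samples / (n_classes * count_c)`, so rarer fire-size classes are up-weighted. A minimal pure-Python sketch of that heuristic, using made-up labels rather than the actual dataset:

```python
from collections import Counter

def balanced_weights(labels):
    """Replicate sklearn's 'balanced' heuristic: n_samples / (n_classes * count_c)."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * counts[c]) for c in sorted(counts)}

# Hypothetical imbalanced ordinal labels: 0=Small, 1=Medium, 2=Large
y = [0] * 6 + [1] * 3 + [2] * 1
print(balanced_weights(y))  # the rare Large class gets the largest weight
```

Per-sample weights then follow by looking each label up in this dict, exactly as the `sample_weights` array is built above.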
scripts/06_evaluate.py ADDED
@@ -0,0 +1,438 @@
1
+ """
2
+ Script 06: Model Evaluation
3
+
4
+ This script performs comprehensive evaluation of the trained model:
5
+ - Confusion matrix visualization
6
+ - Per-class metrics analysis
7
+ - Ordinal-specific metrics (linear weighted kappa)
8
+ - SHAP feature importance analysis
9
+ - Error analysis
10
+
11
+ Usage:
12
+ python scripts/06_evaluate.py
13
+ """
14
+
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ import joblib
19
+ import lightgbm as lgb
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import pandas as pd
23
+ import seaborn as sns
24
+ import shap
25
+ from sklearn.metrics import (
26
+ accuracy_score,
27
+ balanced_accuracy_score,
28
+ classification_report,
29
+ cohen_kappa_score,
30
+ confusion_matrix,
31
+ f1_score,
32
+ precision_score,
33
+ recall_score,
34
+ )
35
+
36
+ # Add project root to path
37
+ project_root = Path(__file__).parent.parent
38
+ sys.path.insert(0, str(project_root))
39
+
40
+ from config.config import (
41
+ TEST_PARQUET,
42
+ TRAIN_PARQUET,
43
+ MODELS_DIR,
44
+ FIGURES_DIR,
45
+ TARGET_COLUMN,
46
+ TARGET_CLASS_NAMES
47
+ )
48
+
49
+ # Set style
50
+ plt.style.use('seaborn-v0_8-whitegrid')
51
+
52
+
53
+ def load_model_and_data() -> tuple:
54
+ """Load trained model, metadata, and test data."""
55
+ print("Loading model and data...")
56
+
57
+ # Load model
58
+ model_path = MODELS_DIR / 'wildfire_model.txt'
59
+ model = lgb.Booster(model_file=str(model_path))
60
+ print(f" Model: {model_path}")
61
+
62
+ # Load metadata
63
+ metadata_path = MODELS_DIR / 'model_metadata.joblib'
64
+ metadata = joblib.load(metadata_path)
65
+ print(f" Metadata: {metadata_path}")
66
+
67
+ # Load test data
68
+ test_df = pd.read_parquet(TEST_PARQUET)
69
+ train_df = pd.read_parquet(TRAIN_PARQUET)
70
+ print(f" Test data: {len(test_df):,} rows")
71
+
72
+ return model, metadata, train_df, test_df
73
+
74
+
75
+ def prepare_data(df: pd.DataFrame, feature_names: list) -> tuple:
76
+ """Prepare features and target from dataframe."""
77
+ X = df[feature_names].values
78
+ y = df[TARGET_COLUMN].values
79
+ return X, y
80
+
81
+
82
+ def compute_all_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> dict:
83
+ """Compute comprehensive metrics."""
84
+
85
+ metrics = {
86
+ # Standard metrics
87
+ 'accuracy': accuracy_score(y_true, y_pred),
88
+ 'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
89
+ 'macro_f1': f1_score(y_true, y_pred, average='macro'),
90
+ 'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
91
+ 'macro_precision': precision_score(y_true, y_pred, average='macro'),
92
+ 'macro_recall': recall_score(y_true, y_pred, average='macro'),
93
+
94
+ # Ordinal-specific: Linear weighted Cohen's Kappa
95
+ # Penalizes predictions farther from true class
96
+ 'cohen_kappa_linear': cohen_kappa_score(y_true, y_pred, weights='linear'),
97
+ 'cohen_kappa_quadratic': cohen_kappa_score(y_true, y_pred, weights='quadratic'),
98
+
99
+ # Per-class metrics
100
+ 'per_class_precision': precision_score(y_true, y_pred, average=None),
101
+ 'per_class_recall': recall_score(y_true, y_pred, average=None),
102
+ 'per_class_f1': f1_score(y_true, y_pred, average=None)
103
+ }
104
+
105
+ return metrics
106
+
107
+
108
+ def print_metrics(metrics: dict) -> None:
109
+ """Print metrics in a formatted way."""
110
+ print("\n" + "="*60)
111
+ print("EVALUATION METRICS")
112
+ print("="*60)
113
+
114
+ print("\nOverall Metrics:")
115
+ print(f" Accuracy: {metrics['accuracy']:.4f}")
116
+ print(f" Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
117
+ print(f" Macro F1: {metrics['macro_f1']:.4f}")
118
+ print(f" Weighted F1: {metrics['weighted_f1']:.4f}")
119
+ print(f" Macro Precision: {metrics['macro_precision']:.4f}")
120
+ print(f" Macro Recall: {metrics['macro_recall']:.4f}")
121
+
122
+ print("\nOrdinal Metrics (penalize distance from true class):")
123
+ print(f" Cohen's Kappa (Linear): {metrics['cohen_kappa_linear']:.4f}")
124
+ print(f" Cohen's Kappa (Quadratic): {metrics['cohen_kappa_quadratic']:.4f}")
125
+
126
+ print("\nPer-Class Metrics:")
127
+ print(f" {'Class':<10} {'Precision':>10} {'Recall':>10} {'F1':>10}")
128
+ print(f" {'-'*40}")
129
+ for i, name in enumerate(TARGET_CLASS_NAMES):
130
+ print(f" {name:<10} {metrics['per_class_precision'][i]:>10.4f} "
131
+ f"{metrics['per_class_recall'][i]:>10.4f} {metrics['per_class_f1'][i]:>10.4f}")
132
+
133
+
134
+ def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_path: Path) -> None:
135
+ """Plot and save confusion matrix."""
136
+ print("\nGenerating confusion matrix...")
137
+
138
+ cm = confusion_matrix(y_true, y_pred)
139
+ cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
140
+
141
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
142
+
143
+ # Raw counts
144
+ ax1 = axes[0]
145
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1,
146
+ xticklabels=TARGET_CLASS_NAMES, yticklabels=TARGET_CLASS_NAMES)
147
+ ax1.set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
148
+ ax1.set_xlabel('Predicted')
149
+ ax1.set_ylabel('Actual')
150
+
151
+ # Normalized (percentages)
152
+ ax2 = axes[1]
153
+ sns.heatmap(cm_normalized, annot=True, fmt='.1%', cmap='Blues', ax=ax2,
154
+ xticklabels=TARGET_CLASS_NAMES, yticklabels=TARGET_CLASS_NAMES)
155
+ ax2.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
156
+ ax2.set_xlabel('Predicted')
157
+ ax2.set_ylabel('Actual')
158
+
159
+ plt.tight_layout()
160
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
161
+ plt.close()
162
+ print(f" Saved: {save_path}")
163
+
164
+
165
+ def plot_classification_report(y_true: np.ndarray, y_pred: np.ndarray, save_path: Path) -> None:
166
+ """Plot classification metrics as bar chart."""
167
+ print("\nGenerating classification report plot...")
168
+
169
+ report = classification_report(y_true, y_pred, target_names=TARGET_CLASS_NAMES, output_dict=True)
170
+
171
+ # Convert to DataFrame
172
+ df_report = pd.DataFrame(report).T
173
+ df_report = df_report.drop(['accuracy', 'macro avg', 'weighted avg'], errors='ignore')
174
+ df_report = df_report[['precision', 'recall', 'f1-score']]
175
+
176
+ fig, ax = plt.subplots(figsize=(10, 6))
177
+
178
+ x = np.arange(len(TARGET_CLASS_NAMES))
179
+ width = 0.25
180
+
181
+ bars1 = ax.bar(x - width, df_report['precision'], width, label='Precision', color='#3498db')
182
+ bars2 = ax.bar(x, df_report['recall'], width, label='Recall', color='#2ecc71')
183
+ bars3 = ax.bar(x + width, df_report['f1-score'], width, label='F1-Score', color='#e74c3c')
184
+
185
+ ax.set_xlabel('Fire Size Class')
186
+ ax.set_ylabel('Score')
187
+ ax.set_title('Per-Class Classification Metrics', fontsize=14, fontweight='bold')
188
+ ax.set_xticks(x)
189
+ ax.set_xticklabels(TARGET_CLASS_NAMES)
190
+ ax.legend()
191
+ ax.set_ylim(0, 1.1)
192
+
193
+ # Add value labels
194
+ for bars in [bars1, bars2, bars3]:
195
+ for bar in bars:
196
+ height = bar.get_height()
197
+ ax.annotate(f'{height:.2f}',
198
+ xy=(bar.get_x() + bar.get_width() / 2, height),
199
+ xytext=(0, 3), textcoords="offset points",
200
+ ha='center', va='bottom', fontsize=8)
201
+
202
+ plt.tight_layout()
203
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
204
+ plt.close()
205
+ print(f" Saved: {save_path}")
206
+
207
+
208
+ def plot_shap_importance(model: lgb.Booster, X: np.ndarray,
209
+ feature_names: list, save_path: Path,
210
+ max_display: int = 20) -> None:
211
+ """Generate SHAP feature importance plots."""
212
+ print("\nGenerating SHAP analysis...")
213
+ print(f" X shape: {X.shape}")
214
+ print(f" Number of feature names: {len(feature_names)}")
215
+
216
+ # Use a sample for SHAP (faster computation)
217
+ sample_size = min(5000, len(X))
218
+ np.random.seed(42)
219
+ sample_idx = np.random.choice(len(X), sample_size, replace=False)
220
+ X_sample = X[sample_idx]
221
+
222
+ # Create explainer
223
+ explainer = shap.TreeExplainer(model)
224
+ shap_values = explainer.shap_values(X_sample)
225
+
226
+ # SHAP values is a list of arrays (one per class for multiclass)
227
+ # Average absolute SHAP values across all classes
228
+ if isinstance(shap_values, list):
229
+ # If it's a list of arrays, each array is (samples, features)
230
+ # We want the mean absolute value for each feature across all samples and all classes
231
+ mean_shap = np.mean([np.abs(sv).mean(axis=0) for sv in shap_values], axis=0)
232
+ else:
233
+ # Handle case where shap_values is a single array (samples, features * classes)
234
+ # or (samples, features)
235
+ mean_shap = np.abs(shap_values).mean(axis=0)
236
+
237
+ # If we have a multiple of features, it's likely multiclass flattened
238
+ num_feats = len(feature_names)
239
+ if mean_shap.size > num_feats and mean_shap.size % num_feats == 0:
240
+ n_classes = mean_shap.size // num_feats
241
+ print(f" Aggregating SHAP values for {n_classes} classes...")
242
+ mean_shap = mean_shap.reshape(n_classes, num_feats).mean(axis=0)
243
+
244
+ # Ensure mean_shap is 1D
245
+ if mean_shap.ndim > 1:
246
+ mean_shap = mean_shap.flatten()
247
+
248
+ print(f" Mean SHAP shape: {mean_shap.shape}")
249
+
250
+ # Handle mismatch between feature_names and mean_shap length
251
+ if len(feature_names) != mean_shap.size:
252
+ print(f" WARNING: Feature names ({len(feature_names)}) != SHAP values ({mean_shap.size})")
253
+ # Trim to match
254
+ n = min(len(feature_names), mean_shap.size)
255
+ feature_names = feature_names[:n]
256
+ mean_shap = mean_shap[:n]
257
+ print(f" Trimmed to {n} features")
258
+
259
+ # Create feature importance DataFrame
260
+ importance_df = pd.DataFrame({
261
+ 'feature': feature_names,
262
+ 'importance': mean_shap
263
+ }).sort_values('importance', ascending=True)
264
+
265
+ # Plot 1: Feature Importance Bar Chart
266
+ plt.figure(figsize=(10, 8))
267
+ top_features = importance_df.tail(max_display)
268
+ plt.barh(top_features['feature'], top_features['importance'], color='steelblue')
269
+ plt.xlabel('Mean |SHAP Value|')
270
+ plt.title(f'Top {max_display} Feature Importance (SHAP)', fontsize=14, fontweight='bold')
271
+ plt.grid(axis='x', alpha=0.3)
272
+ plt.tight_layout()
273
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
274
+ plt.close()
275
+ print(f" Saved importance plot: {save_path}")
276
+
277
+ # Plot 2: SHAP Summary Plot (Large Fires)
278
+ # Extract SHAP values for Large fire class (class index 2)
279
+ shap_values_large = None
280
+ num_feats = len(feature_names)
281
+
282
+ if isinstance(shap_values, list) and len(shap_values) > 2:
283
+ # Already a list of arrays per class
284
+ shap_values_large = shap_values[2]
285
+ elif isinstance(shap_values, np.ndarray):
286
+ # Single array - need to reshape if it's (samples, features * classes)
287
+ if shap_values.shape[1] == num_feats * 3:
288
+ # Reshape from (samples, features*classes) to (samples, classes, features)
289
+ # Then extract class 2 (Large fires)
290
+ reshaped = shap_values.reshape(shap_values.shape[0], 3, num_feats)
291
+ shap_values_large = reshaped[:, 2, :] # Class 2 = Large
292
+ print(f" Extracted Large fire SHAP values: {shap_values_large.shape}")
293
+ elif shap_values.shape[1] == num_feats:
294
+ # Binary or single output - use as-is
295
+ shap_values_large = shap_values
296
+
297
+ if shap_values_large is not None:
298
+ summary_path = save_path.parent / f"{save_path.stem}_summary{save_path.suffix}"
299
+ plt.figure(figsize=(10, 8))
300
+ try:
301
+ print(" Generating SHAP summary plot...")
302
+ shap.summary_plot(shap_values_large, X_sample, feature_names=feature_names,
303
+ max_display=max_display, show=False)
304
+ plt.title('SHAP Summary: Large Fire Class', fontsize=14, fontweight='bold')
305
+ plt.tight_layout()
306
+ plt.savefig(summary_path, dpi=150, bbox_inches='tight')
307
+ print(f" Saved summary plot: {summary_path}")
308
+ except Exception as e:
309
+ print(f" Could not generate summary plot: {e}")
310
+ plt.close()
311
+ else:
312
+ print(" Skipping summary plot (could not extract Large class SHAP values)")
313
+
314
+ # Print top features
315
+ print("\n Top 10 Most Important Features:")
316
+ for _, row in importance_df.tail(10).iloc[::-1].iterrows():
317
+ print(f" {row['feature']}: {row['importance']:.4f}")
318
+
319
+ return importance_df
320
+
321
+
322
+ def analyze_errors(test_df: pd.DataFrame, y_true: np.ndarray,
323
+ y_pred: np.ndarray, save_path: Path) -> None:
324
+ """Analyze misclassifications."""
325
+ print("\nAnalyzing misclassifications...")
326
+
327
+ # Add predictions to dataframe
328
+ test_df = test_df.copy()
329
+ test_df['predicted'] = y_pred
330
+ test_df['correct'] = y_true == y_pred
331
+
332
+ errors = test_df[~test_df['correct']]
333
+
334
+ print(f"\n Total errors: {len(errors):,} ({len(errors)/len(test_df)*100:.1f}%)")
335
+
336
+ # Error types
337
+ print("\n Error Distribution:")
338
+ for true_class in range(3):
339
+ for pred_class in range(3):
340
+ if true_class != pred_class:
341
+ count = ((y_true == true_class) & (y_pred == pred_class)).sum()
342
+ if count > 0:
343
+ pct = count / len(errors) * 100
344
+ true_name = TARGET_CLASS_NAMES[true_class]
345
+ pred_name = TARGET_CLASS_NAMES[pred_class]
346
+ print(f" {true_name} → {pred_name}: {count:,} ({pct:.1f}%)")
347
+
348
+ # Adjacent vs non-adjacent errors (important for ordinal)
349
+ adjacent_errors = 0
350
+ non_adjacent_errors = 0
351
+
352
+ for true_class, pred_class in zip(y_true[y_true != y_pred], y_pred[y_true != y_pred]):
353
+ if abs(true_class - pred_class) == 1:
354
+ adjacent_errors += 1
355
+ else:
356
+ non_adjacent_errors += 1
357
+
358
+ print(f"\n Ordinal Error Analysis:")
359
+ print(f" Adjacent errors (off by 1): {adjacent_errors:,} ({adjacent_errors/len(errors)*100:.1f}%)")
360
+ print(f" Non-adjacent errors (off by 2): {non_adjacent_errors:,} ({non_adjacent_errors/len(errors)*100:.1f}%)")
361
+
362
+
363
+ def plot_prediction_distribution(y_true: np.ndarray, y_pred: np.ndarray,
364
+ y_proba: np.ndarray, save_path: Path) -> None:
365
+ """Plot prediction probability distributions."""
366
+ print("\nGenerating prediction distribution plots...")
367
+
368
+ fig, axes = plt.subplots(1, 3, figsize=(15, 4))
369
+
370
+ for i, (ax, class_name) in enumerate(zip(axes, TARGET_CLASS_NAMES)):
371
+ # Get probabilities for this class
372
+ proba = y_proba[:, i]
373
+
374
+ # Split by actual class
375
+ for true_class in range(3):
376
+ mask = y_true == true_class
377
+ ax.hist(proba[mask], bins=50, alpha=0.5,
378
+ label=f'Actual: {TARGET_CLASS_NAMES[true_class]}', density=True)
379
+
380
+ ax.set_xlabel(f'P({class_name})')
381
+ ax.set_ylabel('Density')
382
+ ax.set_title(f'Predicted Probability: {class_name}', fontweight='bold')
383
+ ax.legend(fontsize=8)
384
+
385
+ plt.tight_layout()
386
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
387
+ plt.close()
388
+ print(f" Saved: {save_path}")
389
+
390
+
391
+ def main():
392
+ """Main evaluation pipeline."""
393
+ print("\n" + "="*60)
394
+ print("MODEL EVALUATION")
395
+ print("="*60)
396
+
397
+ # Create figures directory
398
+ FIGURES_DIR.mkdir(parents=True, exist_ok=True)
399
+
400
+ # Load model and data
401
+ model, metadata, train_df, test_df = load_model_and_data()
402
+ feature_names = metadata['feature_names']
403
+
404
+ # Prepare data
405
+ X_test, y_test = prepare_data(test_df, feature_names)
406
+ X_train, y_train = prepare_data(train_df, feature_names)
407
+
408
+ # Make predictions
409
+ y_proba = model.predict(X_test)
410
+ y_pred = np.argmax(y_proba, axis=1)
411
+
412
+ # Compute metrics
413
+ metrics = compute_all_metrics(y_test, y_pred, y_proba)
414
+ print_metrics(metrics)
415
+
416
+ # Generate plots
417
+ plot_confusion_matrix(y_test, y_pred, FIGURES_DIR / 'confusion_matrix.png')
418
+ plot_classification_report(y_test, y_pred, FIGURES_DIR / 'classification_metrics.png')
419
+ plot_prediction_distribution(y_test, y_pred, y_proba, FIGURES_DIR / 'prediction_distribution.png')
420
+
421
+ # SHAP analysis
422
+ importance_df = plot_shap_importance(model, X_test, feature_names,
423
+ FIGURES_DIR / 'shap_importance.png')
424
+
425
+ # Error analysis
426
+ analyze_errors(test_df, y_test, y_pred, FIGURES_DIR / 'error_analysis.png')
427
+
428
+ # Save importance rankings
429
+ importance_df.to_csv(FIGURES_DIR / 'feature_importance.csv', index=False)
430
+
431
+ print("\n" + "="*60)
432
+ print("✓ Evaluation Complete!")
433
+ print(f" Figures saved to: {FIGURES_DIR}")
434
+ print("="*60 + "\n")
435
+
436
+
437
+ if __name__ == "__main__":
438
+ main()
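The ordinal error split in `analyze_errors` reduces to counting how far each misclassification lands from the true class: off-by-one errors (e.g. Medium predicted as Large) are much less harmful than Small↔Large swaps, which is also why linear weighted kappa is reported. A small self-contained sketch with made-up 3-class labels:

```python
def ordinal_error_split(y_true, y_pred):
    """Count misclassifications landing in an adjacent class vs. further away."""
    adjacent = sum(1 for t, p in zip(y_true, y_pred) if t != p and abs(t - p) == 1)
    non_adjacent = sum(1 for t, p in zip(y_true, y_pred) if t != p and abs(t - p) > 1)
    return adjacent, non_adjacent

# Hypothetical labels: 0=Small, 1=Medium, 2=Large
y_true = [0, 0, 1, 2, 2, 1]
y_pred = [0, 1, 1, 0, 1, 2]
print(ordinal_error_split(y_true, y_pred))  # → (3, 1): one Large→Small swap
```

A model where most errors are adjacent is behaving sensibly for an ordinal target even when plain accuracy looks mediocre.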
scripts/07_predict.py ADDED
@@ -0,0 +1,312 @@
1
+ """
2
+ Script 07: Prediction Pipeline
3
+
4
+ This script provides inference capabilities:
5
+ - Load trained model
6
+ - Preprocess new data
7
+ - Generate predictions with probabilities
8
+ - Can be used as a module or standalone script
9
+
10
+ Usage:
11
+ # Single prediction
12
+ python scripts/07_predict.py --lat 34.05 --lon -118.24 --state CA --cause "Debris Burning" --month 7
13
+
14
+ # Batch prediction from CSV
15
+ python scripts/07_predict.py --input new_fires.csv --output predictions.csv
16
+ """
17
+
18
+ import argparse
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Optional
22
+
23
+ import joblib
24
+ import lightgbm as lgb
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ # Add project root to path
29
+ project_root = Path(__file__).parent.parent
30
+ sys.path.insert(0, str(project_root))
31
+
32
+ from config.config import (
33
+ MODELS_DIR,
34
+ TARGET_CLASS_NAMES,
35
+ FIRE_SIZE_CLASS_MAPPING,
36
+ CATEGORICAL_FEATURES,
37
+ N_GEO_CLUSTERS,
38
+ LAT_BINS,
39
+ LON_BINS
40
+ )
41
+
42
+
43
+ class WildfirePredictor:
44
+ """Wildfire size class predictor."""
45
+
46
+ def __init__(self, model_dir: Path = MODELS_DIR):
47
+ """Initialize predictor with trained model."""
48
+ self.model_dir = model_dir
49
+ self.model = None
50
+ self.metadata = None
51
+ self.feature_names = None
52
+ self.encoders = {}
53
+
54
+ self._load_model()
55
+
56
+ def _load_model(self) -> None:
57
+ """Load trained model and metadata."""
58
+ model_path = self.model_dir / 'wildfire_model.txt'
59
+ metadata_path = self.model_dir / 'model_metadata.joblib'
60
+
61
+ if not model_path.exists():
62
+ raise FileNotFoundError(f"Model not found at {model_path}. Run training first.")
63
+
64
+ self.model = lgb.Booster(model_file=str(model_path))
65
+ self.metadata = joblib.load(metadata_path)
66
+ self.feature_names = self.metadata['feature_names']
67
+
68
+ print(f"Loaded model with {len(self.feature_names)} features")
69
+
70
+ def _create_features(self, df: pd.DataFrame) -> pd.DataFrame:
71
+ """Create features for prediction."""
72
+ df = df.copy()
73
+
74
+ # Ensure required columns exist
75
+ required = ['LATITUDE', 'LONGITUDE', 'FIRE_YEAR', 'DISCOVERY_DOY']
76
+ for col in required:
77
+ if col not in df.columns:
78
+ raise ValueError(f"Missing required column: {col}")
79
+
80
+ # Temporal features
81
+ reference_year = 2001
82
+ df['temp_date'] = pd.to_datetime(
83
+ df['DISCOVERY_DOY'].astype(int).astype(str) + f'-{reference_year}',
84
+ format='%j-%Y',
85
+ errors='coerce'
86
+ )
87
+
88
+ df['month'] = df['temp_date'].dt.month
89
+         df['day_of_week'] = df['temp_date'].dt.dayofweek
+         df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
+         df['season'] = df['month'].apply(lambda m:
+             1 if m in [12, 1, 2] else
+             2 if m in [3, 4, 5] else
+             3 if m in [6, 7, 8] else 4
+         )
+         df['is_fire_season'] = df['month'].isin([6, 7, 8, 9, 10]).astype(int)
+
+         # Cyclical features (preserve the wrap-around of calendar variables)
+         df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
+         df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
+         df['doy_sin'] = np.sin(2 * np.pi * df['DISCOVERY_DOY'] / 365)
+         df['doy_cos'] = np.cos(2 * np.pi * df['DISCOVERY_DOY'] / 365)
+         df['dow_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
+         df['dow_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
+
+         # Year features
+         min_year, max_year = 1992, 2015
+         df['year_normalized'] = (df['FIRE_YEAR'] - min_year) / (max_year - min_year)
+         df['years_since_1992'] = df['FIRE_YEAR'] - min_year
+
+         # Geospatial features: bin coordinates over the continental US extent
+         lat_min, lat_max = 24.0, 50.0
+         lon_min, lon_max = -125.0, -66.0
+         lat_edges = np.linspace(lat_min, lat_max, LAT_BINS + 1)
+         lon_edges = np.linspace(lon_min, lon_max, LON_BINS + 1)
+
+         df['lat_bin'] = pd.cut(df['LATITUDE'], bins=lat_edges, labels=False, include_lowest=True)
+         df['lon_bin'] = pd.cut(df['LONGITUDE'], bins=lon_edges, labels=False, include_lowest=True)
+         # Coordinates outside the bin range (e.g. Alaska, Hawaii) fall back to bin 5
+         df['lat_bin'] = df['lat_bin'].fillna(5).astype(int)
+         df['lon_bin'] = df['lon_bin'].fillna(5).astype(int)
+
+         # Coordinate features
+         df['lat_squared'] = df['LATITUDE'] ** 2
+         df['lon_squared'] = df['LONGITUDE'] ** 2
+         df['lat_lon_interaction'] = df['LATITUDE'] * df['LONGITUDE']
+
+         # Euclidean distance (in degrees) from the approximate US centroid
+         center_lat, center_lon = 39.8, -98.6
+         df['dist_from_center'] = np.sqrt(
+             (df['LATITUDE'] - center_lat) ** 2 +
+             (df['LONGITUDE'] - center_lon) ** 2
+         )
+
+         # Placeholder for geo_cluster (would need the fitted KMeans model)
+         df['geo_cluster'] = 0
+
+         # Drop temporary columns
+         df = df.drop(columns=['temp_date'], errors='ignore')
+
+         return df
+
+     def _encode_categoricals(self, df: pd.DataFrame) -> pd.DataFrame:
+         """Encode categorical variables."""
+         df = df.copy()
+
+         # Fallback label encoding for standalone inference.
+         # Note: hash() is salted per process (PYTHONHASHSEED), so this encoding
+         # is not stable across runs; in production, reuse the encoders fit at
+         # training time.
+         for col in CATEGORICAL_FEATURES:
+             encoded_col = f'{col}_encoded'
+             if col in df.columns:
+                 df[encoded_col] = df[col].astype(str).apply(lambda x: hash(x) % 100)
+             else:
+                 df[encoded_col] = 0
+
+         return df
+
+     def preprocess(self, df: pd.DataFrame) -> np.ndarray:
+         """Preprocess data for prediction."""
+         df = self._create_features(df)
+         df = self._encode_categoricals(df)
+
+         # Select and order features to match training
+         missing_features = [f for f in self.feature_names if f not in df.columns]
+         if missing_features:
+             print(f"Warning: Missing features (filled with 0): {missing_features}")
+             for f in missing_features:
+                 df[f] = 0
+
+         X = df[self.feature_names].values
+         return X
+
+     def predict(self, df: pd.DataFrame) -> pd.DataFrame:
+         """Generate predictions for input data."""
+         X = self.preprocess(df)
+
+         # Get class probabilities, shape (n_samples, n_classes)
+         proba = self.model.predict(X)
+         pred_class = np.argmax(proba, axis=1)
+
+         # Create results dataframe
+         results = df.copy()
+         results['predicted_class'] = pred_class
+         results['predicted_label'] = [TARGET_CLASS_NAMES[c] for c in pred_class]
+         results['prob_small'] = proba[:, 0]
+         results['prob_medium'] = proba[:, 1]
+         results['prob_large'] = proba[:, 2]
+         results['confidence'] = np.max(proba, axis=1)
+
+         return results
+
+     def predict_single(self, latitude: float, longitude: float,
+                        fire_year: int, discovery_doy: int,
+                        state: str = 'Unknown',
+                        cause: str = 'Unknown',
+                        agency: str = 'Unknown',
+                        owner: str = 'Unknown') -> dict:
+         """Predict the size class for a single fire event."""
+
+         df = pd.DataFrame([{
+             'LATITUDE': latitude,
+             'LONGITUDE': longitude,
+             'FIRE_YEAR': fire_year,
+             'DISCOVERY_DOY': discovery_doy,
+             'STATE': state,
+             'STAT_CAUSE_DESCR': cause,
+             'NWCG_REPORTING_AGENCY': agency,
+             'OWNER_DESCR': owner
+         }])
+
+         result = self.predict(df).iloc[0]
+
+         return {
+             'predicted_class': int(result['predicted_class']),
+             'predicted_label': result['predicted_label'],
+             'probabilities': {
+                 'Small': float(result['prob_small']),
+                 'Medium': float(result['prob_medium']),
+                 'Large': float(result['prob_large'])
+             },
+             'confidence': float(result['confidence'])
+         }
+
+
+ def main():
+     """Main prediction script."""
+     parser = argparse.ArgumentParser(description='Wildfire size prediction')
+
+     # Single-prediction arguments
+     parser.add_argument('--lat', type=float, help='Latitude')
+     parser.add_argument('--lon', type=float, help='Longitude')
+     parser.add_argument('--year', type=int, default=2015, help='Fire year')
+     parser.add_argument('--doy', type=int, default=200, help='Day of year')
+     parser.add_argument('--state', type=str, default='Unknown', help='State code')
+     parser.add_argument('--cause', type=str, default='Unknown', help='Fire cause')
+
+     # Batch-prediction arguments
+     parser.add_argument('--input', type=str, help='Input CSV file for batch prediction')
+     parser.add_argument('--output', type=str, help='Output CSV file for predictions')
+
+     args = parser.parse_args()
+
+     # Initialize predictor
+     predictor = WildfirePredictor()
+
+     if args.input:
+         # Batch prediction
+         print(f"\nProcessing batch predictions from: {args.input}")
+         df = pd.read_csv(args.input)
+         results = predictor.predict(df)
+
+         output_path = args.output or 'predictions.csv'
+         results.to_csv(output_path, index=False)
+         print(f"Predictions saved to: {output_path}")
+
+     elif args.lat is not None and args.lon is not None:
+         # Single prediction
+         print("\n" + "=" * 60)
+         print("SINGLE FIRE PREDICTION")
+         print("=" * 60)
+
+         result = predictor.predict_single(
+             latitude=args.lat,
+             longitude=args.lon,
+             fire_year=args.year,
+             discovery_doy=args.doy,
+             state=args.state,
+             cause=args.cause
+         )
+
+         print("\nInput:")
+         print(f"  Location: ({args.lat}, {args.lon})")
+         print(f"  Year: {args.year}, Day of Year: {args.doy}")
+         print(f"  State: {args.state}, Cause: {args.cause}")
+
+         print("\nPrediction:")
+         print(f"  Class: {result['predicted_class']} ({result['predicted_label']})")
+         print(f"  Confidence: {result['confidence']:.1%}")
+
+         print("\nProbabilities:")
+         for label, prob in result['probabilities'].items():
+             bar = '█' * int(prob * 20)
+             print(f"  {label:>6}: {prob:>6.1%} {bar}")
+
+     else:
+         # Demo prediction
+         print("\n" + "=" * 60)
+         print("DEMO PREDICTION")
+         print("=" * 60)
+
+         # Example: summer fire in the Los Angeles area
+         result = predictor.predict_single(
+             latitude=34.05,
+             longitude=-118.24,
+             fire_year=2015,
+             discovery_doy=200,  # mid-July
+             state='CA',
+             cause='Debris Burning'
+         )
+
+         print("\nExample: Summer fire in Los Angeles area")
+         print(f"  Predicted: {result['predicted_label']} (confidence: {result['confidence']:.1%})")
+         print(f"  Probabilities: Small={result['probabilities']['Small']:.1%}, "
+               f"Medium={result['probabilities']['Medium']:.1%}, "
+               f"Large={result['probabilities']['Large']:.1%}")
+
+         print("\nUsage:")
+         print("  Single: python 07_predict.py --lat 34.05 --lon -118.24 --state CA --cause 'Lightning'")
+         print("  Batch:  python 07_predict.py --input fires.csv --output predictions.csv")
+
+
+ if __name__ == "__main__":
+     main()