Commit ยท
79a10a0
0
Parent(s):
first commit
Browse filesThis view is limited to 50 files because it contains too many changes. ย
See raw diff
- .gitignore +6 -0
- Analysis_code/1.data_preprocessing/0.air_data_merge.ipynb +1469 -0
- Analysis_code/1.data_preprocessing/1.data_merge.ipynb +0 -0
- Analysis_code/1.data_preprocessing/3.make_train_test.ipynb +1099 -0
- Analysis_code/2.make_oversample_data/gpu0.log +0 -0
- Analysis_code/2.make_oversample_data/gpu1.log +0 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_10000_1.py +316 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_10000_2.py +317 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_10000_3.py +317 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_20000_1.py +316 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_20000_2.py +317 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_20000_3.py +317 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_7000_1.py +317 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_7000_2.py +317 -0
- Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_7000_3.py +317 -0
- Analysis_code/2.make_oversample_data/run_ctgan_gpu0.bash +58 -0
- Analysis_code/2.make_oversample_data/run_ctgan_gpu1.bash +58 -0
- Analysis_code/2.make_oversample_data/smote_only/smote_sample_1.py +86 -0
- Analysis_code/2.make_oversample_data/smote_only/smote_sample_2.py +86 -0
- Analysis_code/2.make_oversample_data/smote_only/smote_sample_3.py +86 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_10000_1.py +375 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_10000_2.py +376 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_10000_3.py +376 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_20000_1.py +375 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_20000_2.py +376 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_20000_3.py +376 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_7000_1.py +378 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_7000_2.py +376 -0
- Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_7000_3.py +376 -0
- Analysis_code/3.sampled_data_analysis/make_plot.py +659 -0
- Analysis_code/3.sampled_data_analysis/oversampling_model_hyperparameter.ipynb +574 -0
- Analysis_code/4.sampling_data_test/analysis.ipynb +244 -0
- Analysis_code/4.sampling_data_test/lgb_sampled_test.ipynb +0 -0
- Analysis_code/4.sampling_data_test/xgb_sampled_test.ipynb +0 -0
- Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_busan.py +98 -0
- Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_daegu.py +99 -0
- Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_daejeon.py +99 -0
- Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_gwangju.py +99 -0
- Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_incheon.py +99 -0
- Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_seoul.py +99 -0
- Analysis_code/5.optima/deepgbm_pure/utils.py +720 -0
- Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_busan.py +97 -0
- Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_daegu.py +97 -0
- Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_daejeon.py +97 -0
- Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_gwangju.py +97 -0
- Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_incheon.py +97 -0
- Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_seoul.py +97 -0
- Analysis_code/5.optima/deepgbm_smote/utils.py +720 -0
- Analysis_code/5.optima/deepgbm_smotenc_ctgan20000/deepgbm_smotenc_ctgan20000_busan.py +97 -0
- Analysis_code/5.optima/deepgbm_smotenc_ctgan20000/deepgbm_smotenc_ctgan20000_daegu.py +97 -0
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/*
|
| 2 |
+
Analysis_code/3.sampling_data_test/images/*
|
| 3 |
+
Analysis_code/3.sampled_data_analysis/images/*
|
| 4 |
+
__pycache__/
|
| 5 |
+
Analysis_code/optimization_history/*
|
| 6 |
+
Analysis_code/save_model/*
|
Analysis_code/1.data_preprocessing/0.air_data_merge.ipynb
ADDED
|
@@ -0,0 +1,1469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stdout",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"Package Version\n",
|
| 13 |
+
"----------------------------- ------------------\n",
|
| 14 |
+
"absl-py 1.4.0\n",
|
| 15 |
+
"accelerate 0.24.0.dev0\n",
|
| 16 |
+
"aiofiles 23.2.1\n",
|
| 17 |
+
"aiohttp 3.8.5\n",
|
| 18 |
+
"aiosignal 1.3.1\n",
|
| 19 |
+
"alabaster 0.7.13\n",
|
| 20 |
+
"albumentations 1.3.1\n",
|
| 21 |
+
"alembic 1.12.0\n",
|
| 22 |
+
"annotated-types 0.5.0\n",
|
| 23 |
+
"anyio 4.0.0\n",
|
| 24 |
+
"appdirs 1.4.4\n",
|
| 25 |
+
"argon2-cffi 23.1.0\n",
|
| 26 |
+
"argon2-cffi-bindings 21.2.0\n",
|
| 27 |
+
"array-record 0.4.1\n",
|
| 28 |
+
"arrow 1.2.3\n",
|
| 29 |
+
"asttokens 2.4.0\n",
|
| 30 |
+
"astunparse 1.6.3\n",
|
| 31 |
+
"async-lru 2.0.4\n",
|
| 32 |
+
"async-timeout 4.0.3\n",
|
| 33 |
+
"attrs 23.1.0\n",
|
| 34 |
+
"audioread 3.0.0\n",
|
| 35 |
+
"Babel 2.12.1\n",
|
| 36 |
+
"backcall 0.2.0\n",
|
| 37 |
+
"backoff 2.2.1\n",
|
| 38 |
+
"bcrypt 4.0.1\n",
|
| 39 |
+
"beautifulsoup4 4.12.2\n",
|
| 40 |
+
"bitsandbytes 0.41.1\n",
|
| 41 |
+
"black 23.9.1\n",
|
| 42 |
+
"bleach 6.0.0\n",
|
| 43 |
+
"blis 0.7.10\n",
|
| 44 |
+
"branca 0.6.0\n",
|
| 45 |
+
"Brotli 1.1.0\n",
|
| 46 |
+
"cachetools 5.3.1\n",
|
| 47 |
+
"captum 0.6.0\n",
|
| 48 |
+
"catalogue 2.0.9\n",
|
| 49 |
+
"catalyst 22.4\n",
|
| 50 |
+
"catboost 1.2.1.1\n",
|
| 51 |
+
"certifi 2023.7.22\n",
|
| 52 |
+
"cffi 1.15.1\n",
|
| 53 |
+
"charset-normalizer 3.2.0\n",
|
| 54 |
+
"chroma-hnswlib 0.7.3\n",
|
| 55 |
+
"chromadb 0.4.10\n",
|
| 56 |
+
"click 8.1.7\n",
|
| 57 |
+
"cloudpickle 2.2.1\n",
|
| 58 |
+
"cmaes 0.10.0\n",
|
| 59 |
+
"cmake 3.27.5\n",
|
| 60 |
+
"cmdstanpy 1.1.0\n",
|
| 61 |
+
"coloredlogs 15.0.1\n",
|
| 62 |
+
"colorlog 6.7.0\n",
|
| 63 |
+
"comm 0.1.4\n",
|
| 64 |
+
"confection 0.1.3\n",
|
| 65 |
+
"contourpy 1.1.1\n",
|
| 66 |
+
"convertdate 2.4.0\n",
|
| 67 |
+
"cubinlinker-cu11 0.3.0.post1\n",
|
| 68 |
+
"cuda-python 11.8.2\n",
|
| 69 |
+
"cudf-cu11 23.8.0\n",
|
| 70 |
+
"cuml-cu11 23.8.0\n",
|
| 71 |
+
"cupy-cuda11x 12.2.0\n",
|
| 72 |
+
"curio 1.6\n",
|
| 73 |
+
"customized-konlpy 0.0.64\n",
|
| 74 |
+
"cycler 0.11.0\n",
|
| 75 |
+
"cymem 2.0.8\n",
|
| 76 |
+
"cysignals 1.11.2\n",
|
| 77 |
+
"Cython 3.0.2\n",
|
| 78 |
+
"dask 2023.7.1\n",
|
| 79 |
+
"dask-cuda 23.8.0\n",
|
| 80 |
+
"dask-cudf-cu11 23.8.0\n",
|
| 81 |
+
"dataclasses-json 0.5.14\n",
|
| 82 |
+
"datasets 2.14.5\n",
|
| 83 |
+
"debugpy 1.8.0\n",
|
| 84 |
+
"decorator 5.1.1\n",
|
| 85 |
+
"defusedxml 0.7.1\n",
|
| 86 |
+
"dill 0.3.7\n",
|
| 87 |
+
"distributed 2023.7.1\n",
|
| 88 |
+
"dm-tree 0.1.8\n",
|
| 89 |
+
"dnspython 2.4.2\n",
|
| 90 |
+
"docker-pycreds 0.4.0\n",
|
| 91 |
+
"docrepr 0.2.0\n",
|
| 92 |
+
"docutils 0.18.1\n",
|
| 93 |
+
"duckduckgo-search 3.8.5\n",
|
| 94 |
+
"entrypoints 0.4\n",
|
| 95 |
+
"ephem 4.1.4\n",
|
| 96 |
+
"etils 1.4.1\n",
|
| 97 |
+
"exceptiongroup 1.1.3\n",
|
| 98 |
+
"executing 1.2.0\n",
|
| 99 |
+
"fastai 2.7.12\n",
|
| 100 |
+
"fastapi 0.99.1\n",
|
| 101 |
+
"fastcore 1.5.29\n",
|
| 102 |
+
"fastdownload 0.0.7\n",
|
| 103 |
+
"fastjsonschema 2.18.0\n",
|
| 104 |
+
"fastprogress 1.0.3\n",
|
| 105 |
+
"fastrlock 0.8.2\n",
|
| 106 |
+
"fasttext 0.9.2\n",
|
| 107 |
+
"filelock 3.12.4\n",
|
| 108 |
+
"flatbuffers 23.5.26\n",
|
| 109 |
+
"folium 0.14.0\n",
|
| 110 |
+
"fonttools 4.42.1\n",
|
| 111 |
+
"fqdn 1.5.1\n",
|
| 112 |
+
"frozenlist 1.4.0\n",
|
| 113 |
+
"fsspec 2023.6.0\n",
|
| 114 |
+
"future 0.18.3\n",
|
| 115 |
+
"fvcore 0.1.5.post20221221\n",
|
| 116 |
+
"gast 0.4.0\n",
|
| 117 |
+
"gensim 4.3.2\n",
|
| 118 |
+
"gitdb 4.0.10\n",
|
| 119 |
+
"GitPython 3.1.36\n",
|
| 120 |
+
"google-auth 2.23.0\n",
|
| 121 |
+
"google-auth-oauthlib 1.0.0\n",
|
| 122 |
+
"google-pasta 0.2.0\n",
|
| 123 |
+
"googleapis-common-protos 1.60.0\n",
|
| 124 |
+
"graphviz 0.20.1\n",
|
| 125 |
+
"greenlet 2.0.2\n",
|
| 126 |
+
"grpcio 1.58.0\n",
|
| 127 |
+
"h11 0.14.0\n",
|
| 128 |
+
"h2 4.1.0\n",
|
| 129 |
+
"h5py 3.9.0\n",
|
| 130 |
+
"holidays 0.33\n",
|
| 131 |
+
"hpack 4.0.0\n",
|
| 132 |
+
"httpcore 0.18.0\n",
|
| 133 |
+
"httptools 0.6.0\n",
|
| 134 |
+
"httpx 0.25.0\n",
|
| 135 |
+
"huggingface-hub 0.16.4\n",
|
| 136 |
+
"humanfriendly 10.0\n",
|
| 137 |
+
"hydra-slayer 0.4.1\n",
|
| 138 |
+
"hyperframe 6.0.1\n",
|
| 139 |
+
"hyperopt 0.2.7\n",
|
| 140 |
+
"idna 3.4\n",
|
| 141 |
+
"imageio 2.31.3\n",
|
| 142 |
+
"imagesize 1.4.1\n",
|
| 143 |
+
"importlib-metadata 6.8.0\n",
|
| 144 |
+
"importlib-resources 6.0.1\n",
|
| 145 |
+
"iniconfig 2.0.0\n",
|
| 146 |
+
"intel-openmp 2023.2.0\n",
|
| 147 |
+
"iopath 0.1.10\n",
|
| 148 |
+
"ipykernel 6.25.2\n",
|
| 149 |
+
"ipyparallel 8.6.1\n",
|
| 150 |
+
"ipython 8.15.0\n",
|
| 151 |
+
"ipython-genutils 0.2.0\n",
|
| 152 |
+
"ipywidgets 8.1.1\n",
|
| 153 |
+
"isoduration 20.11.0\n",
|
| 154 |
+
"jedi 0.19.0\n",
|
| 155 |
+
"Jinja2 3.1.2\n",
|
| 156 |
+
"joblib 1.3.2\n",
|
| 157 |
+
"JPype1 1.4.1\n",
|
| 158 |
+
"JPype1-py3 0.5.5.4\n",
|
| 159 |
+
"json5 0.9.14\n",
|
| 160 |
+
"jsonpointer 2.4\n",
|
| 161 |
+
"jsonschema 4.19.0\n",
|
| 162 |
+
"jsonschema-specifications 2023.7.1\n",
|
| 163 |
+
"jupyter 1.0.0\n",
|
| 164 |
+
"jupyter_client 8.3.1\n",
|
| 165 |
+
"jupyter-console 6.6.3\n",
|
| 166 |
+
"jupyter_core 5.3.1\n",
|
| 167 |
+
"jupyter-events 0.7.0\n",
|
| 168 |
+
"jupyter-lsp 2.2.0\n",
|
| 169 |
+
"jupyter_server 2.7.3\n",
|
| 170 |
+
"jupyter_server_terminals 0.4.4\n",
|
| 171 |
+
"jupyterlab 4.0.6\n",
|
| 172 |
+
"jupyterlab-pygments 0.2.2\n",
|
| 173 |
+
"jupyterlab_server 2.25.0\n",
|
| 174 |
+
"jupyterlab-widgets 3.0.9\n",
|
| 175 |
+
"jupyterthemes 0.20.0\n",
|
| 176 |
+
"kaggle 1.5.16\n",
|
| 177 |
+
"keras 2.13.1\n",
|
| 178 |
+
"kiwisolver 1.4.5\n",
|
| 179 |
+
"konlpy 0.6.0\n",
|
| 180 |
+
"kornia 0.7.0\n",
|
| 181 |
+
"krwordrank 1.0.3\n",
|
| 182 |
+
"langchain 0.0.295\n",
|
| 183 |
+
"langcodes 3.3.0\n",
|
| 184 |
+
"langsmith 0.0.38\n",
|
| 185 |
+
"lazy_loader 0.3\n",
|
| 186 |
+
"lesscpy 0.15.1\n",
|
| 187 |
+
"libclang 16.0.6\n",
|
| 188 |
+
"librosa 0.10.1\n",
|
| 189 |
+
"lightgbm 4.1.0\n",
|
| 190 |
+
"lit 16.0.6\n",
|
| 191 |
+
"llvmlite 0.40.1\n",
|
| 192 |
+
"locket 1.0.0\n",
|
| 193 |
+
"loguru 0.7.2\n",
|
| 194 |
+
"LunarCalendar 0.0.9\n",
|
| 195 |
+
"lxml 4.9.3\n",
|
| 196 |
+
"Mako 1.2.4\n",
|
| 197 |
+
"Markdown 3.4.4\n",
|
| 198 |
+
"MarkupSafe 2.1.3\n",
|
| 199 |
+
"marshmallow 3.20.1\n",
|
| 200 |
+
"matplotlib 3.8.0\n",
|
| 201 |
+
"matplotlib-inline 0.1.6\n",
|
| 202 |
+
"mecab-python3 1.0.7\n",
|
| 203 |
+
"missingno 0.5.2\n",
|
| 204 |
+
"mistune 3.0.1\n",
|
| 205 |
+
"mkl 2023.2.0\n",
|
| 206 |
+
"mlxtend 0.22.0\n",
|
| 207 |
+
"monotonic 1.6\n",
|
| 208 |
+
"mpmath 1.3.0\n",
|
| 209 |
+
"msgpack 1.0.5\n",
|
| 210 |
+
"multidict 6.0.4\n",
|
| 211 |
+
"multiprocess 0.70.15\n",
|
| 212 |
+
"murmurhash 1.0.10\n",
|
| 213 |
+
"mypy-extensions 1.0.0\n",
|
| 214 |
+
"nbclient 0.8.0\n",
|
| 215 |
+
"nbconvert 7.8.0\n",
|
| 216 |
+
"nbformat 5.9.2\n",
|
| 217 |
+
"nest-asyncio 1.5.8\n",
|
| 218 |
+
"networkx 3.1\n",
|
| 219 |
+
"nltk 3.8.1\n",
|
| 220 |
+
"notebook 7.0.3\n",
|
| 221 |
+
"notebook_shim 0.2.3\n",
|
| 222 |
+
"numba 0.57.1\n",
|
| 223 |
+
"numexpr 2.8.6\n",
|
| 224 |
+
"numpy 1.24.3\n",
|
| 225 |
+
"nvidia-cublas-cu11 11.10.3.66\n",
|
| 226 |
+
"nvidia-cuda-cupti-cu11 11.7.101\n",
|
| 227 |
+
"nvidia-cuda-nvrtc-cu11 11.7.99\n",
|
| 228 |
+
"nvidia-cuda-runtime-cu11 11.7.99\n",
|
| 229 |
+
"nvidia-cudnn-cu11 8.5.0.96\n",
|
| 230 |
+
"nvidia-cufft-cu11 10.9.0.58\n",
|
| 231 |
+
"nvidia-curand-cu11 10.2.10.91\n",
|
| 232 |
+
"nvidia-cusolver-cu11 11.4.0.1\n",
|
| 233 |
+
"nvidia-cusparse-cu11 11.7.4.91\n",
|
| 234 |
+
"nvidia-nccl-cu11 2.14.3\n",
|
| 235 |
+
"nvidia-nvtx-cu11 11.7.91\n",
|
| 236 |
+
"nvtx 0.2.8\n",
|
| 237 |
+
"oauthlib 3.2.2\n",
|
| 238 |
+
"onnxruntime 1.15.1\n",
|
| 239 |
+
"openai 0.28.0\n",
|
| 240 |
+
"opencv-python 4.8.0.76\n",
|
| 241 |
+
"opencv-python-headless 4.8.0.76\n",
|
| 242 |
+
"opt-einsum 3.3.0\n",
|
| 243 |
+
"optuna 3.3.0\n",
|
| 244 |
+
"outcome 1.2.0\n",
|
| 245 |
+
"overrides 7.4.0\n",
|
| 246 |
+
"packaging 23.1\n",
|
| 247 |
+
"pandas 1.5.3\n",
|
| 248 |
+
"pandocfilters 1.5.0\n",
|
| 249 |
+
"parso 0.8.3\n",
|
| 250 |
+
"partd 1.4.0\n",
|
| 251 |
+
"pathspec 0.11.2\n",
|
| 252 |
+
"pathtools 0.1.2\n",
|
| 253 |
+
"pathy 0.10.2\n",
|
| 254 |
+
"patsy 0.5.3\n",
|
| 255 |
+
"peft 0.6.0.dev0\n",
|
| 256 |
+
"pexpect 4.8.0\n",
|
| 257 |
+
"pickleshare 0.7.5\n",
|
| 258 |
+
"Pillow 10.0.1\n",
|
| 259 |
+
"pinecone-client 2.2.4\n",
|
| 260 |
+
"pip 23.2.1\n",
|
| 261 |
+
"platformdirs 3.10.0\n",
|
| 262 |
+
"plotly 5.17.0\n",
|
| 263 |
+
"pluggy 1.3.0\n",
|
| 264 |
+
"ply 3.11\n",
|
| 265 |
+
"pooch 1.7.0\n",
|
| 266 |
+
"portalocker 2.8.2\n",
|
| 267 |
+
"posthog 3.0.2\n",
|
| 268 |
+
"preshed 3.0.9\n",
|
| 269 |
+
"prometheus-client 0.17.1\n",
|
| 270 |
+
"promise 2.3\n",
|
| 271 |
+
"prompt-toolkit 3.0.39\n",
|
| 272 |
+
"prophet 1.1.4\n",
|
| 273 |
+
"protobuf 4.24.3\n",
|
| 274 |
+
"psutil 5.9.5\n",
|
| 275 |
+
"ptxcompiler-cu11 0.7.0.post1\n",
|
| 276 |
+
"ptyprocess 0.7.0\n",
|
| 277 |
+
"pulsar-client 3.3.0\n",
|
| 278 |
+
"pure-eval 0.2.2\n",
|
| 279 |
+
"py 1.11.0\n",
|
| 280 |
+
"py4j 0.10.9.7\n",
|
| 281 |
+
"pyarrow 11.0.0\n",
|
| 282 |
+
"pyasn1 0.5.0\n",
|
| 283 |
+
"pyasn1-modules 0.3.0\n",
|
| 284 |
+
"pybind11 2.11.1\n",
|
| 285 |
+
"pycparser 2.21\n",
|
| 286 |
+
"pydantic 1.10.12\n",
|
| 287 |
+
"pydantic_core 2.6.3\n",
|
| 288 |
+
"pydicom 2.4.3\n",
|
| 289 |
+
"pyfasttext 0.4.6\n",
|
| 290 |
+
"Pygments 2.16.1\n",
|
| 291 |
+
"pygraphviz 1.11\n",
|
| 292 |
+
"pylibraft-cu11 23.8.0\n",
|
| 293 |
+
"PyMeeus 0.5.12\n",
|
| 294 |
+
"PyMySQL 1.1.0\n",
|
| 295 |
+
"pynvml 11.4.1\n",
|
| 296 |
+
"pyparsing 3.1.1\n",
|
| 297 |
+
"pypdf 3.16.1\n",
|
| 298 |
+
"PyPika 0.48.9\n",
|
| 299 |
+
"pystan 2.19.1.1\n",
|
| 300 |
+
"pytest 6.2.5\n",
|
| 301 |
+
"pytest-asyncio 0.20.3\n",
|
| 302 |
+
"python-dateutil 2.8.2\n",
|
| 303 |
+
"python-dotenv 1.0.0\n",
|
| 304 |
+
"python-json-logger 2.0.7\n",
|
| 305 |
+
"python-slugify 8.0.1\n",
|
| 306 |
+
"pytz 2023.3.post1\n",
|
| 307 |
+
"PyWavelets 1.4.1\n",
|
| 308 |
+
"PyYAML 6.0.1\n",
|
| 309 |
+
"pyzmq 25.1.1\n",
|
| 310 |
+
"qtconsole 5.4.4\n",
|
| 311 |
+
"QtPy 2.4.0\n",
|
| 312 |
+
"qudida 0.0.4\n",
|
| 313 |
+
"raft-dask-cu11 23.8.0\n",
|
| 314 |
+
"referencing 0.30.2\n",
|
| 315 |
+
"regex 2023.8.8\n",
|
| 316 |
+
"requests 2.31.0\n",
|
| 317 |
+
"requests-oauthlib 1.3.1\n",
|
| 318 |
+
"rfc3339-validator 0.1.4\n",
|
| 319 |
+
"rfc3986-validator 0.1.1\n",
|
| 320 |
+
"rmm-cu11 23.8.0\n",
|
| 321 |
+
"rpds-py 0.10.3\n",
|
| 322 |
+
"rsa 4.9\n",
|
| 323 |
+
"safetensors 0.3.3\n",
|
| 324 |
+
"scikit-image 0.21.0\n",
|
| 325 |
+
"scikit-learn 1.3.0\n",
|
| 326 |
+
"scipy 1.11.2\n",
|
| 327 |
+
"seaborn 0.12.2\n",
|
| 328 |
+
"Send2Trash 1.8.2\n",
|
| 329 |
+
"sentencepiece 0.1.99\n",
|
| 330 |
+
"sentry-sdk 1.31.0\n",
|
| 331 |
+
"setproctitle 1.3.2\n",
|
| 332 |
+
"setuptools 68.0.0\n",
|
| 333 |
+
"shap 0.42.1\n",
|
| 334 |
+
"six 1.16.0\n",
|
| 335 |
+
"slicer 0.0.7\n",
|
| 336 |
+
"smart-open 6.4.0\n",
|
| 337 |
+
"smmap 5.0.1\n",
|
| 338 |
+
"sniffio 1.3.0\n",
|
| 339 |
+
"snowballstemmer 2.2.0\n",
|
| 340 |
+
"socksio 1.0.0\n",
|
| 341 |
+
"sortedcontainers 2.4.0\n",
|
| 342 |
+
"soundfile 0.12.1\n",
|
| 343 |
+
"soupsieve 2.5\n",
|
| 344 |
+
"soxr 0.3.6\n",
|
| 345 |
+
"soynlp 0.0.493\n",
|
| 346 |
+
"soyspacing 1.0.17\n",
|
| 347 |
+
"spacy 3.6.1\n",
|
| 348 |
+
"spacy-legacy 3.0.12\n",
|
| 349 |
+
"spacy-loggers 1.0.5\n",
|
| 350 |
+
"Sphinx 7.2.6\n",
|
| 351 |
+
"sphinx-rtd-theme 1.3.0\n",
|
| 352 |
+
"sphinxcontrib-applehelp 1.0.7\n",
|
| 353 |
+
"sphinxcontrib-devhelp 1.0.5\n",
|
| 354 |
+
"sphinxcontrib-htmlhelp 2.0.4\n",
|
| 355 |
+
"sphinxcontrib-jquery 4.1\n",
|
| 356 |
+
"sphinxcontrib-jsmath 1.0.1\n",
|
| 357 |
+
"sphinxcontrib-qthelp 1.0.6\n",
|
| 358 |
+
"sphinxcontrib-serializinghtml 1.1.9\n",
|
| 359 |
+
"SQLAlchemy 2.0.21\n",
|
| 360 |
+
"srsly 2.4.7\n",
|
| 361 |
+
"stack-data 0.6.2\n",
|
| 362 |
+
"starlette 0.27.0\n",
|
| 363 |
+
"statsmodels 0.14.0\n",
|
| 364 |
+
"sympy 1.12\n",
|
| 365 |
+
"tabulate 0.9.0\n",
|
| 366 |
+
"tbb 2021.10.0\n",
|
| 367 |
+
"tblib 2.0.0\n",
|
| 368 |
+
"tenacity 8.2.3\n",
|
| 369 |
+
"tensorboard 2.13.0\n",
|
| 370 |
+
"tensorboard-data-server 0.7.1\n",
|
| 371 |
+
"tensorboardX 2.6.2.2\n",
|
| 372 |
+
"tensorflow 2.13.0\n",
|
| 373 |
+
"tensorflow-datasets 4.9.3\n",
|
| 374 |
+
"tensorflow-estimator 2.13.0\n",
|
| 375 |
+
"tensorflow-io-gcs-filesystem 0.34.0\n",
|
| 376 |
+
"tensorflow-metadata 1.14.0\n",
|
| 377 |
+
"termcolor 2.3.0\n",
|
| 378 |
+
"terminado 0.17.1\n",
|
| 379 |
+
"testpath 0.6.0\n",
|
| 380 |
+
"text-unidecode 1.3\n",
|
| 381 |
+
"thinc 8.1.12\n",
|
| 382 |
+
"threadpoolctl 3.2.0\n",
|
| 383 |
+
"tifffile 2023.9.18\n",
|
| 384 |
+
"tiktoken 0.5.1\n",
|
| 385 |
+
"tinycss2 1.2.1\n",
|
| 386 |
+
"tokenizers 0.14.0\n",
|
| 387 |
+
"toml 0.10.2\n",
|
| 388 |
+
"tomli 2.0.1\n",
|
| 389 |
+
"toolz 0.12.0\n",
|
| 390 |
+
"torch 2.0.0\n",
|
| 391 |
+
"torchaudio 2.0.2+cu118\n",
|
| 392 |
+
"torchdata 0.6.0\n",
|
| 393 |
+
"torchsummary 1.5.1\n",
|
| 394 |
+
"torchtext 0.15.1\n",
|
| 395 |
+
"torchtriton 2.0.0+f16138d447\n",
|
| 396 |
+
"torchvision 0.15.2\n",
|
| 397 |
+
"tornado 6.3.3\n",
|
| 398 |
+
"tqdm 4.66.1\n",
|
| 399 |
+
"traitlets 5.10.0\n",
|
| 400 |
+
"transformers 4.34.0.dev0\n",
|
| 401 |
+
"treelite 3.2.0\n",
|
| 402 |
+
"treelite-runtime 3.2.0\n",
|
| 403 |
+
"trio 0.22.2\n",
|
| 404 |
+
"triton 2.0.0\n",
|
| 405 |
+
"typer 0.9.0\n",
|
| 406 |
+
"typing_extensions 4.5.0\n",
|
| 407 |
+
"typing-inspect 0.9.0\n",
|
| 408 |
+
"tzdata 2023.3\n",
|
| 409 |
+
"ucx-py-cu11 0.33.0\n",
|
| 410 |
+
"uri-template 1.3.0\n",
|
| 411 |
+
"urllib3 1.26.16\n",
|
| 412 |
+
"uvicorn 0.23.2\n",
|
| 413 |
+
"uvloop 0.17.0\n",
|
| 414 |
+
"wandb 0.15.10\n",
|
| 415 |
+
"wasabi 1.1.2\n",
|
| 416 |
+
"watchfiles 0.20.0\n",
|
| 417 |
+
"wcwidth 0.2.6\n",
|
| 418 |
+
"webcolors 1.13\n",
|
| 419 |
+
"webencodings 0.5.1\n",
|
| 420 |
+
"websocket-client 1.6.3\n",
|
| 421 |
+
"websockets 11.0.3\n",
|
| 422 |
+
"Werkzeug 2.3.7\n",
|
| 423 |
+
"wheel 0.38.4\n",
|
| 424 |
+
"widgetsnbextension 4.0.9\n",
|
| 425 |
+
"wordcloud 1.9.2\n",
|
| 426 |
+
"wrapt 1.15.0\n",
|
| 427 |
+
"xgboost 2.0.0\n",
|
| 428 |
+
"xxhash 3.3.0\n",
|
| 429 |
+
"yacs 0.1.8\n",
|
| 430 |
+
"yarl 1.9.2\n",
|
| 431 |
+
"zict 3.0.0\n",
|
| 432 |
+
"zipp 3.17.0\n"
|
| 433 |
+
]
|
| 434 |
+
}
|
| 435 |
+
],
|
| 436 |
+
"source": [
|
| 437 |
+
"!pip list"
|
| 438 |
+
]
|
| 439 |
+
},
|
| 440 |
+
{
|
| 441 |
+
"cell_type": "code",
|
| 442 |
+
"execution_count": 1,
|
| 443 |
+
"metadata": {},
|
| 444 |
+
"outputs": [
|
| 445 |
+
{
|
| 446 |
+
"ename": "ModuleNotFoundError",
|
| 447 |
+
"evalue": "No module named 'numpy'",
|
| 448 |
+
"output_type": "error",
|
| 449 |
+
"traceback": [
|
| 450 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 451 |
+
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
|
| 452 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpd\u001b[39;00m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnatsort\u001b[39;00m\n",
|
| 453 |
+
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'numpy'"
|
| 454 |
+
]
|
| 455 |
+
}
|
| 456 |
+
],
|
| 457 |
+
"source": [
|
| 458 |
+
"import os\n",
|
| 459 |
+
"import numpy as np\n",
|
| 460 |
+
"import pandas as pd\n",
|
| 461 |
+
"import natsort\n",
|
| 462 |
+
"from datetime import datetime\n",
|
| 463 |
+
"from tqdm.auto import tqdm"
|
| 464 |
+
]
|
| 465 |
+
},
|
| 466 |
+
{
|
| 467 |
+
"cell_type": "code",
|
| 468 |
+
"execution_count": 2,
|
| 469 |
+
"metadata": {},
|
| 470 |
+
"outputs": [],
|
| 471 |
+
"source": [
|
| 472 |
+
"def get_data(year):\n",
|
| 473 |
+
" files = natsort.natsorted(os.listdir(f'../../data/๋๊ธฐ์ง/{year}/'))\n",
|
| 474 |
+
" data = []\n",
|
| 475 |
+
" for file in tqdm(files, desc=f\"Reading files...({len(files)})\"):\n",
|
| 476 |
+
" data.append(pd.read_excel(f'../../data/๋๊ธฐ์ง/{year}/{file}', usecols=[\"์ง์ญ\", '๋ง', \"์ธก์ ์์ฝ๋\", \"์ธก์ ์๋ช
\", \"์ธก์ ์ผ์\", \"O3\", \"NO2\", \"PM10\", \"PM25\", \"์ฃผ์\"]))\n",
|
| 477 |
+
"\n",
|
| 478 |
+
" return pd.concat(data)"
|
| 479 |
+
]
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"cell_type": "code",
|
| 483 |
+
"execution_count": 3,
|
| 484 |
+
"metadata": {},
|
| 485 |
+
"outputs": [],
|
| 486 |
+
"source": [
|
| 487 |
+
"# ํฉ์น ๋ฐ์ดํฐ์ ๋ ์ง ์ ๋ณด๋ฅผ ์ถ๊ฐํ๋ค.\n",
|
| 488 |
+
"def add_date(df):\n",
|
| 489 |
+
"\n",
|
| 490 |
+
" df[\"์ธก์ ์ผ์\"] = df[\"์ธก์ ์ผ์\"].astype(str).str[:10]\n",
|
| 491 |
+
" df[\"์ธก์ ์ผ์\"] = pd.to_datetime(df[\"์ธก์ ์ผ์\"], format='%Y%m%d%H', errors=\"coerce\")\n",
|
| 492 |
+
"\n",
|
| 493 |
+
" df[\"year\"] = df[\"์ธก์ ์ผ์\"].dt.year\n",
|
| 494 |
+
" df[\"month\"] = df[\"์ธก์ ์ผ์\"].dt.month\n",
|
| 495 |
+
" df[\"day\"] = df[\"์ธก์ ์ผ์\"].dt.day\n",
|
| 496 |
+
" df[\"hour\"] = df[\"์ธก์ ์ผ์\"].dt.hour\n",
|
| 497 |
+
"\n",
|
| 498 |
+
" return df"
|
| 499 |
+
]
|
| 500 |
+
},
|
| 501 |
+
{
|
| 502 |
+
"cell_type": "code",
|
| 503 |
+
"execution_count": 4,
|
| 504 |
+
"metadata": {},
|
| 505 |
+
"outputs": [
|
| 506 |
+
{
|
| 507 |
+
"name": "stderr",
|
| 508 |
+
"output_type": "stream",
|
| 509 |
+
"text": [
|
| 510 |
+
" 0%| | 0/6 [00:00<?, ?it/s]\n",
|
| 511 |
+
"Reading files...(13): 0%| | 0/13 [00:00<?, ?it/s]\u001b[A\n",
|
| 512 |
+
"Reading files...(13): 8%|โ | 1/13 [00:34<06:57, 34.80s/it]\u001b[A\n",
|
| 513 |
+
"Reading files...(13): 15%|โโ | 2/13 [01:12<06:41, 36.47s/it]\u001b[A\n",
|
| 514 |
+
"Reading files...(13): 23%|โโโ | 3/13 [01:47<05:58, 35.89s/it]\u001b[A\n",
|
| 515 |
+
"Reading files...(13): 31%|โโโ | 4/13 [02:23<05:23, 35.96s/it]\u001b[A\n",
|
| 516 |
+
"Reading files...(13): 38%|โโโโ | 5/13 [02:59<04:47, 35.92s/it]\u001b[A\n",
|
| 517 |
+
"Reading files...(13): 46%|โโโโโ | 6/13 [03:35<04:12, 36.09s/it]\u001b[A\n",
|
| 518 |
+
"Reading files...(13): 62%|โโโโโโโ | 8/13 [04:12<02:16, 27.35s/it]\u001b[A\n",
|
| 519 |
+
"Reading files...(13): 69%|โโโโโโโ | 9/13 [04:46<01:56, 29.05s/it]\u001b[A\n",
|
| 520 |
+
"Reading files...(13): 77%|โโโโโโโโ | 10/13 [05:21<01:31, 30.55s/it]\u001b[A\n",
|
| 521 |
+
"Reading files...(13): 85%|โโโโโโโโโ | 11/13 [05:58<01:04, 32.46s/it]\u001b[A\n",
|
| 522 |
+
"Reading files...(13): 92%|โโโโโโโโโโ| 12/13 [06:37<00:34, 34.28s/it]\u001b[A\n",
|
| 523 |
+
"Reading files...(13): 100%|โโโโโโโโโโ| 13/13 [07:08<00:00, 32.93s/it]\u001b[A\n",
|
| 524 |
+
" 17%|โโ | 1/6 [07:18<36:30, 438.18s/it]\n",
|
| 525 |
+
"Reading files...(13): 0%| | 0/13 [00:00<?, ?it/s]\u001b[A\n",
|
| 526 |
+
"Reading files...(13): 8%|โ | 1/13 [00:43<08:41, 43.43s/it]\u001b[A\n",
|
| 527 |
+
"Reading files...(13): 15%|โโ | 2/13 [01:26<07:56, 43.29s/it]\u001b[A\n",
|
| 528 |
+
"Reading files...(13): 23%|โโโ | 3/13 [02:07<07:02, 42.22s/it]\u001b[A\n",
|
| 529 |
+
"Reading files...(13): 31%|โโโ | 4/13 [02:50<06:23, 42.66s/it]\u001b[A\n",
|
| 530 |
+
"Reading files...(13): 38%|โโโโ | 5/13 [03:28<05:27, 40.90s/it]\u001b[A\n",
|
| 531 |
+
"Reading files...(13): 46%|โโโโโ | 6/13 [04:15<04:59, 42.79s/it]\u001b[A\n",
|
| 532 |
+
"Reading files...(13): 54%|โโโโโโ | 7/13 [04:58<04:18, 43.14s/it]\u001b[A\n",
|
| 533 |
+
"Reading files...(13): 62%|โโโโโโโ | 8/13 [05:43<03:37, 43.47s/it]\u001b[A\n",
|
| 534 |
+
"Reading files...(13): 69%|โโโโโโโ | 9/13 [06:28<02:55, 43.96s/it]\u001b[A\n",
|
| 535 |
+
"Reading files...(13): 77%|โโโโโโโโ | 10/13 [07:12<02:12, 44.01s/it]\u001b[A\n",
|
| 536 |
+
"Reading files...(13): 85%|โโโโโโโโโ | 11/13 [07:52<01:25, 42.90s/it]\u001b[A\n",
|
| 537 |
+
"Reading files...(13): 100%|โโโโโโโโโโ| 13/13 [08:34<00:00, 39.61s/it]\u001b[A\n",
|
| 538 |
+
" 33%|โโโโ | 2/6 [16:05<32:42, 490.55s/it]\n",
|
| 539 |
+
"Reading files...(13): 0%| | 0/13 [00:00<?, ?it/s]\u001b[A\n",
|
| 540 |
+
"Reading files...(13): 8%|โ | 1/13 [00:49<09:56, 49.74s/it]\u001b[A\n",
|
| 541 |
+
"Reading files...(13): 15%|โโ | 2/13 [01:43<09:31, 51.98s/it]\u001b[A\n",
|
| 542 |
+
"Reading files...(13): 23%|โโโ | 3/13 [02:33<08:29, 50.96s/it]\u001b[A\n",
|
| 543 |
+
"Reading files...(13): 31%|โโโ | 4/13 [03:23<07:38, 50.95s/it]\u001b[A\n",
|
| 544 |
+
"Reading files...(13): 38%|โโโโ | 5/13 [04:13<06:43, 50.46s/it]\u001b[A\n",
|
| 545 |
+
"Reading files...(13): 46%|โโโโโ | 6/13 [04:58<05:40, 48.71s/it]\u001b[A\n",
|
| 546 |
+
"Reading files...(13): 54%|โโโโโโ | 7/13 [05:50<04:57, 49.66s/it]\u001b[A\n",
|
| 547 |
+
"Reading files...(13): 62%|โโโโโโโ | 8/13 [06:45<04:16, 51.29s/it]\u001b[A\n",
|
| 548 |
+
"Reading files...(13): 77%|โโโโโโโโ | 10/13 [07:38<01:58, 39.46s/it]\u001b[A\n",
|
| 549 |
+
"Reading files...(13): 85%|โโโโโโโโโ | 11/13 [08:30<01:25, 42.79s/it]\u001b[A\n",
|
| 550 |
+
"Reading files...(13): 92%|โโโโโโโโโโ| 12/13 [09:26<00:46, 46.32s/it]\u001b[A\n",
|
| 551 |
+
"Reading files...(13): 100%|โโโโโโโโโโ| 13/13 [10:13<00:00, 47.19s/it]\u001b[A\n",
|
| 552 |
+
" 50%|โโโโโ | 3/6 [26:32<27:38, 552.96s/it]\n",
|
| 553 |
+
"Reading files...(13): 0%| | 0/13 [00:00<?, ?it/s]\u001b[A\n",
|
| 554 |
+
"Reading files...(13): 8%|โ | 1/13 [00:59<11:48, 59.01s/it]\u001b[A\n",
|
| 555 |
+
"Reading files...(13): 15%|โโ | 2/13 [01:56<10:40, 58.19s/it]\u001b[A\n",
|
| 556 |
+
"Reading files...(13): 23%|โโโ | 3/13 [02:53<09:37, 57.77s/it]\u001b[A\n",
|
| 557 |
+
"Reading files...(13): 31%|โโโ | 4/13 [03:52<08:41, 58.00s/it]\u001b[A\n",
|
| 558 |
+
"Reading files...(13): 38%|โโโโ | 5/13 [04:44<07:26, 55.77s/it]\u001b[A\n",
|
| 559 |
+
"Reading files...(13): 46%|โโโโโ | 6/13 [05:40<06:32, 56.05s/it]\u001b[A\n",
|
| 560 |
+
"Reading files...(13): 54%|โโโโโโ | 7/13 [06:36<05:36, 56.06s/it]\u001b[A\n",
|
| 561 |
+
"Reading files...(13): 62%|โโโโโโโ | 8/13 [07:33<04:42, 56.42s/it]\u001b[A\n",
|
| 562 |
+
"Reading files...(13): 69%|โโโโโโโ | 9/13 [08:34<03:51, 57.76s/it]\u001b[A\n",
|
| 563 |
+
"Reading files...(13): 77%|โโโโโโโโ | 10/13 [09:35<02:56, 58.75s/it]\u001b[A\n",
|
| 564 |
+
"Reading files...(13): 92%|โโโโโโโโโโ| 12/13 [10:33<00:44, 44.84s/it]\u001b[A\n",
|
| 565 |
+
"Reading files...(13): 100%|โโโโโโโโโโ| 13/13 [11:32<00:00, 53.29s/it]\u001b[A\n",
|
| 566 |
+
" 67%|โโโโโโโ | 4/6 [38:20<20:28, 614.26s/it]\n",
|
| 567 |
+
"Reading files...(13): 0%| | 0/13 [00:00<?, ?it/s]\u001b[A\n",
|
| 568 |
+
"Reading files...(13): 8%|โ | 1/13 [00:59<11:57, 59.79s/it]\u001b[A\n",
|
| 569 |
+
"Reading files...(13): 15%|โโ | 2/13 [02:01<11:07, 60.67s/it]\u001b[A\n",
|
| 570 |
+
"Reading files...(13): 23%|โโโ | 3/13 [03:02<10:10, 61.02s/it]\u001b[A\n",
|
| 571 |
+
"Reading files...(13): 31%|โโโ | 4/13 [03:57<08:48, 58.74s/it]\u001b[A\n",
|
| 572 |
+
"Reading files...(13): 38%|โโโโ | 5/13 [04:57<07:53, 59.18s/it]\u001b[A\n",
|
| 573 |
+
"Reading files...(13): 46%|โโโโโ | 6/13 [06:00<07:03, 60.45s/it]\u001b[A\n",
|
| 574 |
+
"Reading files...(13): 54%|โโโโโโ | 7/13 [07:00<06:02, 60.38s/it]\u001b[A\n",
|
| 575 |
+
"Reading files...(13): 62%|โโโโโโโ | 8/13 [08:02<05:04, 60.85s/it]\u001b[A\n",
|
| 576 |
+
"Reading files...(13): 69%|โโโโโโโ | 9/13 [09:04<04:04, 61.03s/it]\u001b[A\n",
|
| 577 |
+
"Reading files...(13): 77%|โโโโโโโโ | 10/13 [10:04<03:02, 60.67s/it]\u001b[A\n",
|
| 578 |
+
"Reading files...(13): 92%|โโโโโโโโโโ| 12/13 [11:06<00:46, 46.76s/it]\u001b[A\n",
|
| 579 |
+
"Reading files...(13): 100%|โโโโโโโโโโ| 13/13 [12:09<00:00, 56.08s/it]\u001b[A\n",
|
| 580 |
+
" 83%|โโโโโโโโโ | 5/6 [50:46<11:01, 661.78s/it]\n",
|
| 581 |
+
"Reading files...(13): 0%| | 0/13 [00:00<?, ?it/s]\u001b[A\n",
|
| 582 |
+
"Reading files...(13): 8%|โ | 1/13 [01:03<12:46, 63.88s/it]\u001b[A\n",
|
| 583 |
+
"Reading files...(13): 15%|โโ | 2/13 [02:08<11:50, 64.56s/it]\u001b[A\n",
|
| 584 |
+
"Reading files...(13): 23%|โโโ | 3/13 [03:10<10:32, 63.22s/it]\u001b[A\n",
|
| 585 |
+
"Reading files...(13): 31%|โโโ | 4/13 [04:07<09:05, 60.63s/it]\u001b[A\n",
|
| 586 |
+
"Reading files...(13): 38%|โโโโ | 5/13 [05:09<08:11, 61.41s/it]\u001b[A\n",
|
| 587 |
+
"Reading files...(13): 46%|โโโโโ | 6/13 [06:12<07:13, 61.92s/it]\u001b[A\n",
|
| 588 |
+
"Reading files...(13): 54%|โโโโโโ | 7/13 [07:13<06:09, 61.50s/it]\u001b[A\n",
|
| 589 |
+
"Reading files...(13): 62%|โโโโโโโ | 8/13 [08:15<05:08, 61.64s/it]\u001b[A\n",
|
| 590 |
+
"Reading files...(13): 69%|โโโโโโโ | 9/13 [09:17<04:07, 61.81s/it]\u001b[A\n",
|
| 591 |
+
"Reading files...(13): 77%|โโโโโโโโ | 10/13 [10:19<03:05, 61.96s/it]\u001b[A\n",
|
| 592 |
+
"Reading files...(13): 92%|โโโโโโโโโโ| 12/13 [11:23<00:47, 47.75s/it]\u001b[A\n",
|
| 593 |
+
"Reading files...(13): 100%|โโโโโโโโโโ| 13/13 [12:27<00:00, 57.50s/it]\u001b[A\n",
|
| 594 |
+
"100%|โโโโโโโโโโ| 6/6 [1:03:31<00:00, 635.28s/it]\n"
|
| 595 |
+
]
|
| 596 |
+
}
|
| 597 |
+
],
|
| 598 |
+
"source": [
|
| 599 |
+
"import os\n",
|
| 600 |
+
"import pandas as pd\n",
|
| 601 |
+
"from tqdm.auto import tqdm\n",
|
| 602 |
+
"\n",
|
| 603 |
+
"# ๋๊ธฐ์ง ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์์ ํ๋์ ํ์ผ๋ก ํฉ์น๋ค.\n",
|
| 604 |
+
"def get_data(year):\n",
|
| 605 |
+
" directory = f'../../data/๋๊ธฐ์ง/{year}/'\n",
|
| 606 |
+
" files = os.listdir(directory)\n",
|
| 607 |
+
" data = []\n",
|
| 608 |
+
" \n",
|
| 609 |
+
" # ํ์ผ ๋ชฉ๋ก์์ ๋๋ ํ ๋ฆฌ๋ฅผ ์ ์ธํ๊ณ ์ค์ง Excel ํ์ผ๋ง ์ฒ๋ฆฌ\n",
|
| 610 |
+
" for file in tqdm(files, desc=f\"Reading files...({len(files)})\"):\n",
|
| 611 |
+
" file_path = os.path.join(directory, file)\n",
|
| 612 |
+
" if os.path.isfile(file_path) and file_path.endswith(('.xls', '.xlsx')): # Excel ํ์ผ ํ์ฅ์๋ง ํ์ฉ\n",
|
| 613 |
+
" data.append(pd.read_excel(file_path, usecols=[\"์ง์ญ\", '๋ง', \"์ธก์ ์์ฝ๋\", \"์ธก์ ์๋ช
\", \"์ธก์ ์ผ์\", \"O3\", \"NO2\", \"PM10\", \"PM25\", \"์ฃผ์\"]))\n",
|
| 614 |
+
" \n",
|
| 615 |
+
" return pd.concat(data)\n",
|
| 616 |
+
"\n",
|
| 617 |
+
"years = [2018, 2019, 2020,2021,2022,2023] # 2018๋
๋ถํฐ 2023๋
๊น์ง์ ๋ฐ์ดํฐ๋ฅผ ํฉ์น๋ค.\n",
|
| 618 |
+
"for year in tqdm(years):\n",
|
| 619 |
+
" data = get_data(year)\n",
|
| 620 |
+
" data = add_date(data)\n",
|
| 621 |
+
" data.reset_index(drop=True, inplace=True)\n",
|
| 622 |
+
" data.to_feather(f\"../../data/๋๊ธฐ์ง/{year}.feather\")\n"
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
{
|
| 626 |
+
"cell_type": "code",
|
| 627 |
+
"execution_count": 6,
|
| 628 |
+
"metadata": {},
|
| 629 |
+
"outputs": [
|
| 630 |
+
{
|
| 631 |
+
"data": {
|
| 632 |
+
"text/html": [
|
| 633 |
+
"<div>\n",
|
| 634 |
+
"<style scoped>\n",
|
| 635 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 636 |
+
" vertical-align: middle;\n",
|
| 637 |
+
" }\n",
|
| 638 |
+
"\n",
|
| 639 |
+
" .dataframe tbody tr th {\n",
|
| 640 |
+
" vertical-align: top;\n",
|
| 641 |
+
" }\n",
|
| 642 |
+
"\n",
|
| 643 |
+
" .dataframe thead th {\n",
|
| 644 |
+
" text-align: right;\n",
|
| 645 |
+
" }\n",
|
| 646 |
+
"</style>\n",
|
| 647 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 648 |
+
" <thead>\n",
|
| 649 |
+
" <tr style=\"text-align: right;\">\n",
|
| 650 |
+
" <th></th>\n",
|
| 651 |
+
" <th>์ง์ญ</th>\n",
|
| 652 |
+
" <th>๋ง</th>\n",
|
| 653 |
+
" <th>์ธก์ ์์ฝ๋</th>\n",
|
| 654 |
+
" <th>์ธก์ ์๋ช
</th>\n",
|
| 655 |
+
" <th>์ธก์ ์ผ์</th>\n",
|
| 656 |
+
" <th>O3</th>\n",
|
| 657 |
+
" <th>NO2</th>\n",
|
| 658 |
+
" <th>PM10</th>\n",
|
| 659 |
+
" <th>PM25</th>\n",
|
| 660 |
+
" <th>์ฃผ์</th>\n",
|
| 661 |
+
" <th>year</th>\n",
|
| 662 |
+
" <th>month</th>\n",
|
| 663 |
+
" <th>day</th>\n",
|
| 664 |
+
" <th>hour</th>\n",
|
| 665 |
+
" </tr>\n",
|
| 666 |
+
" </thead>\n",
|
| 667 |
+
" <tbody>\n",
|
| 668 |
+
" <tr>\n",
|
| 669 |
+
" <th>0</th>\n",
|
| 670 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 671 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 672 |
+
" <td>111121</td>\n",
|
| 673 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 674 |
+
" <td>2023-07-01 01:00:00</td>\n",
|
| 675 |
+
" <td>0.0249</td>\n",
|
| 676 |
+
" <td>0.0188</td>\n",
|
| 677 |
+
" <td>21.0</td>\n",
|
| 678 |
+
" <td>19.0</td>\n",
|
| 679 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 680 |
+
" <td>2023.0</td>\n",
|
| 681 |
+
" <td>7.0</td>\n",
|
| 682 |
+
" <td>1.0</td>\n",
|
| 683 |
+
" <td>1.0</td>\n",
|
| 684 |
+
" </tr>\n",
|
| 685 |
+
" <tr>\n",
|
| 686 |
+
" <th>1</th>\n",
|
| 687 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 688 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 689 |
+
" <td>111121</td>\n",
|
| 690 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 691 |
+
" <td>2023-07-01 02:00:00</td>\n",
|
| 692 |
+
" <td>0.0263</td>\n",
|
| 693 |
+
" <td>0.0163</td>\n",
|
| 694 |
+
" <td>18.0</td>\n",
|
| 695 |
+
" <td>15.0</td>\n",
|
| 696 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 697 |
+
" <td>2023.0</td>\n",
|
| 698 |
+
" <td>7.0</td>\n",
|
| 699 |
+
" <td>1.0</td>\n",
|
| 700 |
+
" <td>2.0</td>\n",
|
| 701 |
+
" </tr>\n",
|
| 702 |
+
" <tr>\n",
|
| 703 |
+
" <th>2</th>\n",
|
| 704 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 705 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 706 |
+
" <td>111121</td>\n",
|
| 707 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 708 |
+
" <td>2023-07-01 03:00:00</td>\n",
|
| 709 |
+
" <td>0.0218</td>\n",
|
| 710 |
+
" <td>0.0192</td>\n",
|
| 711 |
+
" <td>24.0</td>\n",
|
| 712 |
+
" <td>21.0</td>\n",
|
| 713 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 714 |
+
" <td>2023.0</td>\n",
|
| 715 |
+
" <td>7.0</td>\n",
|
| 716 |
+
" <td>1.0</td>\n",
|
| 717 |
+
" <td>3.0</td>\n",
|
| 718 |
+
" </tr>\n",
|
| 719 |
+
" <tr>\n",
|
| 720 |
+
" <th>3</th>\n",
|
| 721 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 722 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 723 |
+
" <td>111121</td>\n",
|
| 724 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 725 |
+
" <td>2023-07-01 04:00:00</td>\n",
|
| 726 |
+
" <td>0.0131</td>\n",
|
| 727 |
+
" <td>0.0214</td>\n",
|
| 728 |
+
" <td>25.0</td>\n",
|
| 729 |
+
" <td>19.0</td>\n",
|
| 730 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 731 |
+
" <td>2023.0</td>\n",
|
| 732 |
+
" <td>7.0</td>\n",
|
| 733 |
+
" <td>1.0</td>\n",
|
| 734 |
+
" <td>4.0</td>\n",
|
| 735 |
+
" </tr>\n",
|
| 736 |
+
" <tr>\n",
|
| 737 |
+
" <th>4</th>\n",
|
| 738 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 739 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 740 |
+
" <td>111121</td>\n",
|
| 741 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 742 |
+
" <td>2023-07-01 05:00:00</td>\n",
|
| 743 |
+
" <td>0.0131</td>\n",
|
| 744 |
+
" <td>0.0160</td>\n",
|
| 745 |
+
" <td>25.0</td>\n",
|
| 746 |
+
" <td>21.0</td>\n",
|
| 747 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 748 |
+
" <td>2023.0</td>\n",
|
| 749 |
+
" <td>7.0</td>\n",
|
| 750 |
+
" <td>1.0</td>\n",
|
| 751 |
+
" <td>5.0</td>\n",
|
| 752 |
+
" </tr>\n",
|
| 753 |
+
" <tr>\n",
|
| 754 |
+
" <th>5</th>\n",
|
| 755 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 756 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 757 |
+
" <td>111121</td>\n",
|
| 758 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 759 |
+
" <td>2023-07-01 06:00:00</td>\n",
|
| 760 |
+
" <td>0.0115</td>\n",
|
| 761 |
+
" <td>0.0196</td>\n",
|
| 762 |
+
" <td>23.0</td>\n",
|
| 763 |
+
" <td>18.0</td>\n",
|
| 764 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 765 |
+
" <td>2023.0</td>\n",
|
| 766 |
+
" <td>7.0</td>\n",
|
| 767 |
+
" <td>1.0</td>\n",
|
| 768 |
+
" <td>6.0</td>\n",
|
| 769 |
+
" </tr>\n",
|
| 770 |
+
" <tr>\n",
|
| 771 |
+
" <th>6</th>\n",
|
| 772 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 773 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 774 |
+
" <td>111121</td>\n",
|
| 775 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 776 |
+
" <td>2023-07-01 07:00:00</td>\n",
|
| 777 |
+
" <td>0.0094</td>\n",
|
| 778 |
+
" <td>0.0230</td>\n",
|
| 779 |
+
" <td>26.0</td>\n",
|
| 780 |
+
" <td>21.0</td>\n",
|
| 781 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 782 |
+
" <td>2023.0</td>\n",
|
| 783 |
+
" <td>7.0</td>\n",
|
| 784 |
+
" <td>1.0</td>\n",
|
| 785 |
+
" <td>7.0</td>\n",
|
| 786 |
+
" </tr>\n",
|
| 787 |
+
" <tr>\n",
|
| 788 |
+
" <th>7</th>\n",
|
| 789 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 790 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 791 |
+
" <td>111121</td>\n",
|
| 792 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 793 |
+
" <td>2023-07-01 08:00:00</td>\n",
|
| 794 |
+
" <td>0.0222</td>\n",
|
| 795 |
+
" <td>0.0175</td>\n",
|
| 796 |
+
" <td>26.0</td>\n",
|
| 797 |
+
" <td>20.0</td>\n",
|
| 798 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 799 |
+
" <td>2023.0</td>\n",
|
| 800 |
+
" <td>7.0</td>\n",
|
| 801 |
+
" <td>1.0</td>\n",
|
| 802 |
+
" <td>8.0</td>\n",
|
| 803 |
+
" </tr>\n",
|
| 804 |
+
" <tr>\n",
|
| 805 |
+
" <th>8</th>\n",
|
| 806 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 807 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 808 |
+
" <td>111121</td>\n",
|
| 809 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 810 |
+
" <td>2023-07-01 09:00:00</td>\n",
|
| 811 |
+
" <td>0.0396</td>\n",
|
| 812 |
+
" <td>0.0153</td>\n",
|
| 813 |
+
" <td>27.0</td>\n",
|
| 814 |
+
" <td>20.0</td>\n",
|
| 815 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 816 |
+
" <td>2023.0</td>\n",
|
| 817 |
+
" <td>7.0</td>\n",
|
| 818 |
+
" <td>1.0</td>\n",
|
| 819 |
+
" <td>9.0</td>\n",
|
| 820 |
+
" </tr>\n",
|
| 821 |
+
" <tr>\n",
|
| 822 |
+
" <th>9</th>\n",
|
| 823 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 824 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 825 |
+
" <td>111121</td>\n",
|
| 826 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 827 |
+
" <td>2023-07-01 10:00:00</td>\n",
|
| 828 |
+
" <td>0.0530</td>\n",
|
| 829 |
+
" <td>0.0105</td>\n",
|
| 830 |
+
" <td>19.0</td>\n",
|
| 831 |
+
" <td>16.0</td>\n",
|
| 832 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 833 |
+
" <td>2023.0</td>\n",
|
| 834 |
+
" <td>7.0</td>\n",
|
| 835 |
+
" <td>1.0</td>\n",
|
| 836 |
+
" <td>10.0</td>\n",
|
| 837 |
+
" </tr>\n",
|
| 838 |
+
" <tr>\n",
|
| 839 |
+
" <th>10</th>\n",
|
| 840 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 841 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 842 |
+
" <td>111121</td>\n",
|
| 843 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 844 |
+
" <td>2023-07-01 11:00:00</td>\n",
|
| 845 |
+
" <td>0.0607</td>\n",
|
| 846 |
+
" <td>0.0090</td>\n",
|
| 847 |
+
" <td>20.0</td>\n",
|
| 848 |
+
" <td>20.0</td>\n",
|
| 849 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 850 |
+
" <td>2023.0</td>\n",
|
| 851 |
+
" <td>7.0</td>\n",
|
| 852 |
+
" <td>1.0</td>\n",
|
| 853 |
+
" <td>11.0</td>\n",
|
| 854 |
+
" </tr>\n",
|
| 855 |
+
" <tr>\n",
|
| 856 |
+
" <th>11</th>\n",
|
| 857 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 858 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 859 |
+
" <td>111121</td>\n",
|
| 860 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 861 |
+
" <td>2023-07-01 12:00:00</td>\n",
|
| 862 |
+
" <td>0.0688</td>\n",
|
| 863 |
+
" <td>0.0114</td>\n",
|
| 864 |
+
" <td>20.0</td>\n",
|
| 865 |
+
" <td>17.0</td>\n",
|
| 866 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 867 |
+
" <td>2023.0</td>\n",
|
| 868 |
+
" <td>7.0</td>\n",
|
| 869 |
+
" <td>1.0</td>\n",
|
| 870 |
+
" <td>12.0</td>\n",
|
| 871 |
+
" </tr>\n",
|
| 872 |
+
" <tr>\n",
|
| 873 |
+
" <th>12</th>\n",
|
| 874 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 875 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 876 |
+
" <td>111121</td>\n",
|
| 877 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 878 |
+
" <td>2023-07-01 13:00:00</td>\n",
|
| 879 |
+
" <td>0.0758</td>\n",
|
| 880 |
+
" <td>0.0101</td>\n",
|
| 881 |
+
" <td>23.0</td>\n",
|
| 882 |
+
" <td>17.0</td>\n",
|
| 883 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 884 |
+
" <td>2023.0</td>\n",
|
| 885 |
+
" <td>7.0</td>\n",
|
| 886 |
+
" <td>1.0</td>\n",
|
| 887 |
+
" <td>13.0</td>\n",
|
| 888 |
+
" </tr>\n",
|
| 889 |
+
" <tr>\n",
|
| 890 |
+
" <th>13</th>\n",
|
| 891 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 892 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 893 |
+
" <td>111121</td>\n",
|
| 894 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 895 |
+
" <td>2023-07-01 14:00:00</td>\n",
|
| 896 |
+
" <td>0.0743</td>\n",
|
| 897 |
+
" <td>0.0093</td>\n",
|
| 898 |
+
" <td>20.0</td>\n",
|
| 899 |
+
" <td>17.0</td>\n",
|
| 900 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 901 |
+
" <td>2023.0</td>\n",
|
| 902 |
+
" <td>7.0</td>\n",
|
| 903 |
+
" <td>1.0</td>\n",
|
| 904 |
+
" <td>14.0</td>\n",
|
| 905 |
+
" </tr>\n",
|
| 906 |
+
" <tr>\n",
|
| 907 |
+
" <th>14</th>\n",
|
| 908 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 909 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 910 |
+
" <td>111121</td>\n",
|
| 911 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 912 |
+
" <td>2023-07-01 15:00:00</td>\n",
|
| 913 |
+
" <td>0.0749</td>\n",
|
| 914 |
+
" <td>0.0100</td>\n",
|
| 915 |
+
" <td>19.0</td>\n",
|
| 916 |
+
" <td>11.0</td>\n",
|
| 917 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 918 |
+
" <td>2023.0</td>\n",
|
| 919 |
+
" <td>7.0</td>\n",
|
| 920 |
+
" <td>1.0</td>\n",
|
| 921 |
+
" <td>15.0</td>\n",
|
| 922 |
+
" </tr>\n",
|
| 923 |
+
" <tr>\n",
|
| 924 |
+
" <th>15</th>\n",
|
| 925 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 926 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 927 |
+
" <td>111121</td>\n",
|
| 928 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 929 |
+
" <td>2023-07-01 16:00:00</td>\n",
|
| 930 |
+
" <td>0.0716</td>\n",
|
| 931 |
+
" <td>0.0092</td>\n",
|
| 932 |
+
" <td>19.0</td>\n",
|
| 933 |
+
" <td>15.0</td>\n",
|
| 934 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 935 |
+
" <td>2023.0</td>\n",
|
| 936 |
+
" <td>7.0</td>\n",
|
| 937 |
+
" <td>1.0</td>\n",
|
| 938 |
+
" <td>16.0</td>\n",
|
| 939 |
+
" </tr>\n",
|
| 940 |
+
" <tr>\n",
|
| 941 |
+
" <th>16</th>\n",
|
| 942 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 943 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 944 |
+
" <td>111121</td>\n",
|
| 945 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 946 |
+
" <td>2023-07-01 17:00:00</td>\n",
|
| 947 |
+
" <td>0.0613</td>\n",
|
| 948 |
+
" <td>0.0099</td>\n",
|
| 949 |
+
" <td>18.0</td>\n",
|
| 950 |
+
" <td>15.0</td>\n",
|
| 951 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 952 |
+
" <td>2023.0</td>\n",
|
| 953 |
+
" <td>7.0</td>\n",
|
| 954 |
+
" <td>1.0</td>\n",
|
| 955 |
+
" <td>17.0</td>\n",
|
| 956 |
+
" </tr>\n",
|
| 957 |
+
" <tr>\n",
|
| 958 |
+
" <th>17</th>\n",
|
| 959 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 960 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 961 |
+
" <td>111121</td>\n",
|
| 962 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 963 |
+
" <td>2023-07-01 18:00:00</td>\n",
|
| 964 |
+
" <td>0.0496</td>\n",
|
| 965 |
+
" <td>0.0098</td>\n",
|
| 966 |
+
" <td>18.0</td>\n",
|
| 967 |
+
" <td>14.0</td>\n",
|
| 968 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 969 |
+
" <td>2023.0</td>\n",
|
| 970 |
+
" <td>7.0</td>\n",
|
| 971 |
+
" <td>1.0</td>\n",
|
| 972 |
+
" <td>18.0</td>\n",
|
| 973 |
+
" </tr>\n",
|
| 974 |
+
" <tr>\n",
|
| 975 |
+
" <th>18</th>\n",
|
| 976 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 977 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 978 |
+
" <td>111121</td>\n",
|
| 979 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 980 |
+
" <td>2023-07-01 19:00:00</td>\n",
|
| 981 |
+
" <td>0.0473</td>\n",
|
| 982 |
+
" <td>0.0124</td>\n",
|
| 983 |
+
" <td>17.0</td>\n",
|
| 984 |
+
" <td>17.0</td>\n",
|
| 985 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 986 |
+
" <td>2023.0</td>\n",
|
| 987 |
+
" <td>7.0</td>\n",
|
| 988 |
+
" <td>1.0</td>\n",
|
| 989 |
+
" <td>19.0</td>\n",
|
| 990 |
+
" </tr>\n",
|
| 991 |
+
" <tr>\n",
|
| 992 |
+
" <th>19</th>\n",
|
| 993 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 994 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 995 |
+
" <td>111121</td>\n",
|
| 996 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 997 |
+
" <td>2023-07-01 20:00:00</td>\n",
|
| 998 |
+
" <td>0.0498</td>\n",
|
| 999 |
+
" <td>0.0170</td>\n",
|
| 1000 |
+
" <td>17.0</td>\n",
|
| 1001 |
+
" <td>15.0</td>\n",
|
| 1002 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1003 |
+
" <td>2023.0</td>\n",
|
| 1004 |
+
" <td>7.0</td>\n",
|
| 1005 |
+
" <td>1.0</td>\n",
|
| 1006 |
+
" <td>20.0</td>\n",
|
| 1007 |
+
" </tr>\n",
|
| 1008 |
+
" <tr>\n",
|
| 1009 |
+
" <th>20</th>\n",
|
| 1010 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1011 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1012 |
+
" <td>111121</td>\n",
|
| 1013 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1014 |
+
" <td>2023-07-01 21:00:00</td>\n",
|
| 1015 |
+
" <td>0.0616</td>\n",
|
| 1016 |
+
" <td>0.0134</td>\n",
|
| 1017 |
+
" <td>23.0</td>\n",
|
| 1018 |
+
" <td>20.0</td>\n",
|
| 1019 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1020 |
+
" <td>2023.0</td>\n",
|
| 1021 |
+
" <td>7.0</td>\n",
|
| 1022 |
+
" <td>1.0</td>\n",
|
| 1023 |
+
" <td>21.0</td>\n",
|
| 1024 |
+
" </tr>\n",
|
| 1025 |
+
" <tr>\n",
|
| 1026 |
+
" <th>21</th>\n",
|
| 1027 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1028 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1029 |
+
" <td>111121</td>\n",
|
| 1030 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1031 |
+
" <td>2023-07-01 22:00:00</td>\n",
|
| 1032 |
+
" <td>0.0543</td>\n",
|
| 1033 |
+
" <td>0.0109</td>\n",
|
| 1034 |
+
" <td>18.0</td>\n",
|
| 1035 |
+
" <td>16.0</td>\n",
|
| 1036 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1037 |
+
" <td>2023.0</td>\n",
|
| 1038 |
+
" <td>7.0</td>\n",
|
| 1039 |
+
" <td>1.0</td>\n",
|
| 1040 |
+
" <td>22.0</td>\n",
|
| 1041 |
+
" </tr>\n",
|
| 1042 |
+
" <tr>\n",
|
| 1043 |
+
" <th>22</th>\n",
|
| 1044 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1045 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1046 |
+
" <td>111121</td>\n",
|
| 1047 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1048 |
+
" <td>2023-07-01 23:00:00</td>\n",
|
| 1049 |
+
" <td>0.0507</td>\n",
|
| 1050 |
+
" <td>0.0113</td>\n",
|
| 1051 |
+
" <td>17.0</td>\n",
|
| 1052 |
+
" <td>16.0</td>\n",
|
| 1053 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1054 |
+
" <td>2023.0</td>\n",
|
| 1055 |
+
" <td>7.0</td>\n",
|
| 1056 |
+
" <td>1.0</td>\n",
|
| 1057 |
+
" <td>23.0</td>\n",
|
| 1058 |
+
" </tr>\n",
|
| 1059 |
+
" <tr>\n",
|
| 1060 |
+
" <th>23</th>\n",
|
| 1061 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1062 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1063 |
+
" <td>111121</td>\n",
|
| 1064 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1065 |
+
" <td>NaT</td>\n",
|
| 1066 |
+
" <td>0.0427</td>\n",
|
| 1067 |
+
" <td>0.0125</td>\n",
|
| 1068 |
+
" <td>17.0</td>\n",
|
| 1069 |
+
" <td>16.0</td>\n",
|
| 1070 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1071 |
+
" <td>NaN</td>\n",
|
| 1072 |
+
" <td>NaN</td>\n",
|
| 1073 |
+
" <td>NaN</td>\n",
|
| 1074 |
+
" <td>NaN</td>\n",
|
| 1075 |
+
" </tr>\n",
|
| 1076 |
+
" <tr>\n",
|
| 1077 |
+
" <th>24</th>\n",
|
| 1078 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1079 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1080 |
+
" <td>111121</td>\n",
|
| 1081 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1082 |
+
" <td>2023-07-02 01:00:00</td>\n",
|
| 1083 |
+
" <td>0.0334</td>\n",
|
| 1084 |
+
" <td>0.0148</td>\n",
|
| 1085 |
+
" <td>21.0</td>\n",
|
| 1086 |
+
" <td>20.0</td>\n",
|
| 1087 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1088 |
+
" <td>2023.0</td>\n",
|
| 1089 |
+
" <td>7.0</td>\n",
|
| 1090 |
+
" <td>2.0</td>\n",
|
| 1091 |
+
" <td>1.0</td>\n",
|
| 1092 |
+
" </tr>\n",
|
| 1093 |
+
" <tr>\n",
|
| 1094 |
+
" <th>25</th>\n",
|
| 1095 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1096 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1097 |
+
" <td>111121</td>\n",
|
| 1098 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1099 |
+
" <td>2023-07-02 02:00:00</td>\n",
|
| 1100 |
+
" <td>0.0337</td>\n",
|
| 1101 |
+
" <td>0.0133</td>\n",
|
| 1102 |
+
" <td>22.0</td>\n",
|
| 1103 |
+
" <td>18.0</td>\n",
|
| 1104 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1105 |
+
" <td>2023.0</td>\n",
|
| 1106 |
+
" <td>7.0</td>\n",
|
| 1107 |
+
" <td>2.0</td>\n",
|
| 1108 |
+
" <td>2.0</td>\n",
|
| 1109 |
+
" </tr>\n",
|
| 1110 |
+
" <tr>\n",
|
| 1111 |
+
" <th>26</th>\n",
|
| 1112 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1113 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1114 |
+
" <td>111121</td>\n",
|
| 1115 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1116 |
+
" <td>2023-07-02 03:00:00</td>\n",
|
| 1117 |
+
" <td>0.0260</td>\n",
|
| 1118 |
+
" <td>0.0162</td>\n",
|
| 1119 |
+
" <td>25.0</td>\n",
|
| 1120 |
+
" <td>20.0</td>\n",
|
| 1121 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1122 |
+
" <td>2023.0</td>\n",
|
| 1123 |
+
" <td>7.0</td>\n",
|
| 1124 |
+
" <td>2.0</td>\n",
|
| 1125 |
+
" <td>3.0</td>\n",
|
| 1126 |
+
" </tr>\n",
|
| 1127 |
+
" <tr>\n",
|
| 1128 |
+
" <th>27</th>\n",
|
| 1129 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1130 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1131 |
+
" <td>111121</td>\n",
|
| 1132 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1133 |
+
" <td>2023-07-02 04:00:00</td>\n",
|
| 1134 |
+
" <td>0.0195</td>\n",
|
| 1135 |
+
" <td>0.0179</td>\n",
|
| 1136 |
+
" <td>22.0</td>\n",
|
| 1137 |
+
" <td>18.0</td>\n",
|
| 1138 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1139 |
+
" <td>2023.0</td>\n",
|
| 1140 |
+
" <td>7.0</td>\n",
|
| 1141 |
+
" <td>2.0</td>\n",
|
| 1142 |
+
" <td>4.0</td>\n",
|
| 1143 |
+
" </tr>\n",
|
| 1144 |
+
" <tr>\n",
|
| 1145 |
+
" <th>28</th>\n",
|
| 1146 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1147 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1148 |
+
" <td>111121</td>\n",
|
| 1149 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1150 |
+
" <td>2023-07-02 05:00:00</td>\n",
|
| 1151 |
+
" <td>0.0171</td>\n",
|
| 1152 |
+
" <td>0.0170</td>\n",
|
| 1153 |
+
" <td>19.0</td>\n",
|
| 1154 |
+
" <td>17.0</td>\n",
|
| 1155 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1156 |
+
" <td>2023.0</td>\n",
|
| 1157 |
+
" <td>7.0</td>\n",
|
| 1158 |
+
" <td>2.0</td>\n",
|
| 1159 |
+
" <td>5.0</td>\n",
|
| 1160 |
+
" </tr>\n",
|
| 1161 |
+
" <tr>\n",
|
| 1162 |
+
" <th>29</th>\n",
|
| 1163 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1164 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1165 |
+
" <td>111121</td>\n",
|
| 1166 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1167 |
+
" <td>2023-07-02 06:00:00</td>\n",
|
| 1168 |
+
" <td>0.0181</td>\n",
|
| 1169 |
+
" <td>0.0145</td>\n",
|
| 1170 |
+
" <td>14.0</td>\n",
|
| 1171 |
+
" <td>10.0</td>\n",
|
| 1172 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1173 |
+
" <td>2023.0</td>\n",
|
| 1174 |
+
" <td>7.0</td>\n",
|
| 1175 |
+
" <td>2.0</td>\n",
|
| 1176 |
+
" <td>6.0</td>\n",
|
| 1177 |
+
" </tr>\n",
|
| 1178 |
+
" <tr>\n",
|
| 1179 |
+
" <th>30</th>\n",
|
| 1180 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1181 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1182 |
+
" <td>111121</td>\n",
|
| 1183 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1184 |
+
" <td>2023-07-02 07:00:00</td>\n",
|
| 1185 |
+
" <td>0.0174</td>\n",
|
| 1186 |
+
" <td>0.0156</td>\n",
|
| 1187 |
+
" <td>11.0</td>\n",
|
| 1188 |
+
" <td>10.0</td>\n",
|
| 1189 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1190 |
+
" <td>2023.0</td>\n",
|
| 1191 |
+
" <td>7.0</td>\n",
|
| 1192 |
+
" <td>2.0</td>\n",
|
| 1193 |
+
" <td>7.0</td>\n",
|
| 1194 |
+
" </tr>\n",
|
| 1195 |
+
" <tr>\n",
|
| 1196 |
+
" <th>31</th>\n",
|
| 1197 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1198 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1199 |
+
" <td>111121</td>\n",
|
| 1200 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1201 |
+
" <td>2023-07-02 08:00:00</td>\n",
|
| 1202 |
+
" <td>0.0213</td>\n",
|
| 1203 |
+
" <td>0.0147</td>\n",
|
| 1204 |
+
" <td>12.0</td>\n",
|
| 1205 |
+
" <td>9.0</td>\n",
|
| 1206 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1207 |
+
" <td>2023.0</td>\n",
|
| 1208 |
+
" <td>7.0</td>\n",
|
| 1209 |
+
" <td>2.0</td>\n",
|
| 1210 |
+
" <td>8.0</td>\n",
|
| 1211 |
+
" </tr>\n",
|
| 1212 |
+
" <tr>\n",
|
| 1213 |
+
" <th>32</th>\n",
|
| 1214 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1215 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1216 |
+
" <td>111121</td>\n",
|
| 1217 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1218 |
+
" <td>2023-07-02 09:00:00</td>\n",
|
| 1219 |
+
" <td>0.0267</td>\n",
|
| 1220 |
+
" <td>0.0143</td>\n",
|
| 1221 |
+
" <td>11.0</td>\n",
|
| 1222 |
+
" <td>10.0</td>\n",
|
| 1223 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1224 |
+
" <td>2023.0</td>\n",
|
| 1225 |
+
" <td>7.0</td>\n",
|
| 1226 |
+
" <td>2.0</td>\n",
|
| 1227 |
+
" <td>9.0</td>\n",
|
| 1228 |
+
" </tr>\n",
|
| 1229 |
+
" <tr>\n",
|
| 1230 |
+
" <th>33</th>\n",
|
| 1231 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1232 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1233 |
+
" <td>111121</td>\n",
|
| 1234 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1235 |
+
" <td>2023-07-02 10:00:00</td>\n",
|
| 1236 |
+
" <td>0.0289</td>\n",
|
| 1237 |
+
" <td>0.0155</td>\n",
|
| 1238 |
+
" <td>12.0</td>\n",
|
| 1239 |
+
" <td>9.0</td>\n",
|
| 1240 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1241 |
+
" <td>2023.0</td>\n",
|
| 1242 |
+
" <td>7.0</td>\n",
|
| 1243 |
+
" <td>2.0</td>\n",
|
| 1244 |
+
" <td>10.0</td>\n",
|
| 1245 |
+
" </tr>\n",
|
| 1246 |
+
" <tr>\n",
|
| 1247 |
+
" <th>34</th>\n",
|
| 1248 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1249 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1250 |
+
" <td>111121</td>\n",
|
| 1251 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1252 |
+
" <td>2023-07-02 11:00:00</td>\n",
|
| 1253 |
+
" <td>0.0381</td>\n",
|
| 1254 |
+
" <td>0.0108</td>\n",
|
| 1255 |
+
" <td>13.0</td>\n",
|
| 1256 |
+
" <td>13.0</td>\n",
|
| 1257 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1258 |
+
" <td>2023.0</td>\n",
|
| 1259 |
+
" <td>7.0</td>\n",
|
| 1260 |
+
" <td>2.0</td>\n",
|
| 1261 |
+
" <td>11.0</td>\n",
|
| 1262 |
+
" </tr>\n",
|
| 1263 |
+
" <tr>\n",
|
| 1264 |
+
" <th>35</th>\n",
|
| 1265 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1266 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1267 |
+
" <td>111121</td>\n",
|
| 1268 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1269 |
+
" <td>2023-07-02 12:00:00</td>\n",
|
| 1270 |
+
" <td>0.0441</td>\n",
|
| 1271 |
+
" <td>0.0079</td>\n",
|
| 1272 |
+
" <td>13.0</td>\n",
|
| 1273 |
+
" <td>12.0</td>\n",
|
| 1274 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1275 |
+
" <td>2023.0</td>\n",
|
| 1276 |
+
" <td>7.0</td>\n",
|
| 1277 |
+
" <td>2.0</td>\n",
|
| 1278 |
+
" <td>12.0</td>\n",
|
| 1279 |
+
" </tr>\n",
|
| 1280 |
+
" <tr>\n",
|
| 1281 |
+
" <th>36</th>\n",
|
| 1282 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1283 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1284 |
+
" <td>111121</td>\n",
|
| 1285 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1286 |
+
" <td>2023-07-02 13:00:00</td>\n",
|
| 1287 |
+
" <td>0.0489</td>\n",
|
| 1288 |
+
" <td>0.0067</td>\n",
|
| 1289 |
+
" <td>8.0</td>\n",
|
| 1290 |
+
" <td>10.0</td>\n",
|
| 1291 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1292 |
+
" <td>2023.0</td>\n",
|
| 1293 |
+
" <td>7.0</td>\n",
|
| 1294 |
+
" <td>2.0</td>\n",
|
| 1295 |
+
" <td>13.0</td>\n",
|
| 1296 |
+
" </tr>\n",
|
| 1297 |
+
" <tr>\n",
|
| 1298 |
+
" <th>37</th>\n",
|
| 1299 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1300 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1301 |
+
" <td>111121</td>\n",
|
| 1302 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1303 |
+
" <td>2023-07-02 14:00:00</td>\n",
|
| 1304 |
+
" <td>0.0498</td>\n",
|
| 1305 |
+
" <td>0.0072</td>\n",
|
| 1306 |
+
" <td>11.0</td>\n",
|
| 1307 |
+
" <td>10.0</td>\n",
|
| 1308 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1309 |
+
" <td>2023.0</td>\n",
|
| 1310 |
+
" <td>7.0</td>\n",
|
| 1311 |
+
" <td>2.0</td>\n",
|
| 1312 |
+
" <td>14.0</td>\n",
|
| 1313 |
+
" </tr>\n",
|
| 1314 |
+
" <tr>\n",
|
| 1315 |
+
" <th>38</th>\n",
|
| 1316 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1317 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1318 |
+
" <td>111121</td>\n",
|
| 1319 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1320 |
+
" <td>2023-07-02 15:00:00</td>\n",
|
| 1321 |
+
" <td>0.0459</td>\n",
|
| 1322 |
+
" <td>0.0073</td>\n",
|
| 1323 |
+
" <td>14.0</td>\n",
|
| 1324 |
+
" <td>12.0</td>\n",
|
| 1325 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1326 |
+
" <td>2023.0</td>\n",
|
| 1327 |
+
" <td>7.0</td>\n",
|
| 1328 |
+
" <td>2.0</td>\n",
|
| 1329 |
+
" <td>15.0</td>\n",
|
| 1330 |
+
" </tr>\n",
|
| 1331 |
+
" <tr>\n",
|
| 1332 |
+
" <th>39</th>\n",
|
| 1333 |
+
" <td>์์ธ ์ค๊ตฌ</td>\n",
|
| 1334 |
+
" <td>๋์๋๊ธฐ</td>\n",
|
| 1335 |
+
" <td>111121</td>\n",
|
| 1336 |
+
" <td>์ค๊ตฌ</td>\n",
|
| 1337 |
+
" <td>2023-07-02 16:00:00</td>\n",
|
| 1338 |
+
" <td>0.0474</td>\n",
|
| 1339 |
+
" <td>0.0079</td>\n",
|
| 1340 |
+
" <td>12.0</td>\n",
|
| 1341 |
+
" <td>11.0</td>\n",
|
| 1342 |
+
" <td>์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15</td>\n",
|
| 1343 |
+
" <td>2023.0</td>\n",
|
| 1344 |
+
" <td>7.0</td>\n",
|
| 1345 |
+
" <td>2.0</td>\n",
|
| 1346 |
+
" <td>16.0</td>\n",
|
| 1347 |
+
" </tr>\n",
|
| 1348 |
+
" </tbody>\n",
|
| 1349 |
+
"</table>\n",
|
| 1350 |
+
"</div>"
|
| 1351 |
+
],
|
| 1352 |
+
"text/plain": [
|
| 1353 |
+
" ์ง์ญ ๋ง ์ธก์ ์์ฝ๋ ์ธก์ ์๋ช
์ธก์ ์ผ์ O3 NO2 PM10 PM25 \\\n",
|
| 1354 |
+
"0 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 01:00:00 0.0249 0.0188 21.0 19.0 \n",
|
| 1355 |
+
"1 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 02:00:00 0.0263 0.0163 18.0 15.0 \n",
|
| 1356 |
+
"2 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 03:00:00 0.0218 0.0192 24.0 21.0 \n",
|
| 1357 |
+
"3 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 04:00:00 0.0131 0.0214 25.0 19.0 \n",
|
| 1358 |
+
"4 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 05:00:00 0.0131 0.0160 25.0 21.0 \n",
|
| 1359 |
+
"5 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 06:00:00 0.0115 0.0196 23.0 18.0 \n",
|
| 1360 |
+
"6 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 07:00:00 0.0094 0.0230 26.0 21.0 \n",
|
| 1361 |
+
"7 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 08:00:00 0.0222 0.0175 26.0 20.0 \n",
|
| 1362 |
+
"8 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 09:00:00 0.0396 0.0153 27.0 20.0 \n",
|
| 1363 |
+
"9 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 10:00:00 0.0530 0.0105 19.0 16.0 \n",
|
| 1364 |
+
"10 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 11:00:00 0.0607 0.0090 20.0 20.0 \n",
|
| 1365 |
+
"11 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 12:00:00 0.0688 0.0114 20.0 17.0 \n",
|
| 1366 |
+
"12 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 13:00:00 0.0758 0.0101 23.0 17.0 \n",
|
| 1367 |
+
"13 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 14:00:00 0.0743 0.0093 20.0 17.0 \n",
|
| 1368 |
+
"14 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 15:00:00 0.0749 0.0100 19.0 11.0 \n",
|
| 1369 |
+
"15 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 16:00:00 0.0716 0.0092 19.0 15.0 \n",
|
| 1370 |
+
"16 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 17:00:00 0.0613 0.0099 18.0 15.0 \n",
|
| 1371 |
+
"17 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 18:00:00 0.0496 0.0098 18.0 14.0 \n",
|
| 1372 |
+
"18 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 19:00:00 0.0473 0.0124 17.0 17.0 \n",
|
| 1373 |
+
"19 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 20:00:00 0.0498 0.0170 17.0 15.0 \n",
|
| 1374 |
+
"20 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 21:00:00 0.0616 0.0134 23.0 20.0 \n",
|
| 1375 |
+
"21 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 22:00:00 0.0543 0.0109 18.0 16.0 \n",
|
| 1376 |
+
"22 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-01 23:00:00 0.0507 0.0113 17.0 16.0 \n",
|
| 1377 |
+
"23 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ NaT 0.0427 0.0125 17.0 16.0 \n",
|
| 1378 |
+
"24 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 01:00:00 0.0334 0.0148 21.0 20.0 \n",
|
| 1379 |
+
"25 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 02:00:00 0.0337 0.0133 22.0 18.0 \n",
|
| 1380 |
+
"26 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 03:00:00 0.0260 0.0162 25.0 20.0 \n",
|
| 1381 |
+
"27 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 04:00:00 0.0195 0.0179 22.0 18.0 \n",
|
| 1382 |
+
"28 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 05:00:00 0.0171 0.0170 19.0 17.0 \n",
|
| 1383 |
+
"29 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 06:00:00 0.0181 0.0145 14.0 10.0 \n",
|
| 1384 |
+
"30 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 07:00:00 0.0174 0.0156 11.0 10.0 \n",
|
| 1385 |
+
"31 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 08:00:00 0.0213 0.0147 12.0 9.0 \n",
|
| 1386 |
+
"32 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 09:00:00 0.0267 0.0143 11.0 10.0 \n",
|
| 1387 |
+
"33 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 10:00:00 0.0289 0.0155 12.0 9.0 \n",
|
| 1388 |
+
"34 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 11:00:00 0.0381 0.0108 13.0 13.0 \n",
|
| 1389 |
+
"35 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 12:00:00 0.0441 0.0079 13.0 12.0 \n",
|
| 1390 |
+
"36 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 13:00:00 0.0489 0.0067 8.0 10.0 \n",
|
| 1391 |
+
"37 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 14:00:00 0.0498 0.0072 11.0 10.0 \n",
|
| 1392 |
+
"38 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 15:00:00 0.0459 0.0073 14.0 12.0 \n",
|
| 1393 |
+
"39 ์์ธ ์ค๊ตฌ ๋์๋๊ธฐ 111121 ์ค๊ตฌ 2023-07-02 16:00:00 0.0474 0.0079 12.0 11.0 \n",
|
| 1394 |
+
"\n",
|
| 1395 |
+
" ์ฃผ์ year month day hour \n",
|
| 1396 |
+
"0 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 1.0 \n",
|
| 1397 |
+
"1 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 2.0 \n",
|
| 1398 |
+
"2 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 3.0 \n",
|
| 1399 |
+
"3 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 4.0 \n",
|
| 1400 |
+
"4 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 5.0 \n",
|
| 1401 |
+
"5 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 6.0 \n",
|
| 1402 |
+
"6 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 7.0 \n",
|
| 1403 |
+
"7 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 8.0 \n",
|
| 1404 |
+
"8 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 9.0 \n",
|
| 1405 |
+
"9 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 10.0 \n",
|
| 1406 |
+
"10 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 11.0 \n",
|
| 1407 |
+
"11 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 12.0 \n",
|
| 1408 |
+
"12 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 13.0 \n",
|
| 1409 |
+
"13 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 14.0 \n",
|
| 1410 |
+
"14 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 15.0 \n",
|
| 1411 |
+
"15 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 16.0 \n",
|
| 1412 |
+
"16 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 17.0 \n",
|
| 1413 |
+
"17 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 18.0 \n",
|
| 1414 |
+
"18 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 19.0 \n",
|
| 1415 |
+
"19 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 20.0 \n",
|
| 1416 |
+
"20 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 21.0 \n",
|
| 1417 |
+
"21 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 22.0 \n",
|
| 1418 |
+
"22 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 1.0 23.0 \n",
|
| 1419 |
+
"23 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 NaN NaN NaN NaN \n",
|
| 1420 |
+
"24 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 1.0 \n",
|
| 1421 |
+
"25 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 2.0 \n",
|
| 1422 |
+
"26 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 3.0 \n",
|
| 1423 |
+
"27 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 4.0 \n",
|
| 1424 |
+
"28 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 5.0 \n",
|
| 1425 |
+
"29 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 6.0 \n",
|
| 1426 |
+
"30 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 7.0 \n",
|
| 1427 |
+
"31 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 8.0 \n",
|
| 1428 |
+
"32 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 9.0 \n",
|
| 1429 |
+
"33 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 10.0 \n",
|
| 1430 |
+
"34 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 11.0 \n",
|
| 1431 |
+
"35 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 12.0 \n",
|
| 1432 |
+
"36 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 13.0 \n",
|
| 1433 |
+
"37 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 14.0 \n",
|
| 1434 |
+
"38 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 15.0 \n",
|
| 1435 |
+
"39 ์์ธ ์ค๊ตฌ ๋์๊ถ๊ธธ 15 2023.0 7.0 2.0 16.0 "
|
| 1436 |
+
]
|
| 1437 |
+
},
|
| 1438 |
+
"execution_count": 6,
|
| 1439 |
+
"metadata": {},
|
| 1440 |
+
"output_type": "execute_result"
|
| 1441 |
+
}
|
| 1442 |
+
],
|
| 1443 |
+
"source": [
|
| 1444 |
+
"data.head(40)"
|
| 1445 |
+
]
|
| 1446 |
+
}
|
| 1447 |
+
],
|
| 1448 |
+
"metadata": {
|
| 1449 |
+
"kernelspec": {
|
| 1450 |
+
"display_name": "py39",
|
| 1451 |
+
"language": "python",
|
| 1452 |
+
"name": "python3"
|
| 1453 |
+
},
|
| 1454 |
+
"language_info": {
|
| 1455 |
+
"codemirror_mode": {
|
| 1456 |
+
"name": "ipython",
|
| 1457 |
+
"version": 3
|
| 1458 |
+
},
|
| 1459 |
+
"file_extension": ".py",
|
| 1460 |
+
"mimetype": "text/x-python",
|
| 1461 |
+
"name": "python",
|
| 1462 |
+
"nbconvert_exporter": "python",
|
| 1463 |
+
"pygments_lexer": "ipython3",
|
| 1464 |
+
"version": "3.9.18"
|
| 1465 |
+
}
|
| 1466 |
+
},
|
| 1467 |
+
"nbformat": 4,
|
| 1468 |
+
"nbformat_minor": 4
|
| 1469 |
+
}
|
Analysis_code/1.data_preprocessing/1.data_merge.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Analysis_code/1.data_preprocessing/3.make_train_test.ipynb
ADDED
|
@@ -0,0 +1,1099 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import pandas as pd\n",
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"import matplotlib.pyplot as plt\n",
|
| 12 |
+
"import seaborn as sns\n",
|
| 13 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 14 |
+
"from collections import Counter"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 2,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"df_seoul = pd.read_feather(\"../../data/data_for_modeling/df_seoul.feather\")\n",
|
| 24 |
+
"df_busan = pd.read_feather(\"../../data/data_for_modeling/df_busan.feather\")\n",
|
| 25 |
+
"df_incheon = pd.read_feather(\"../../data/data_for_modeling/df_incheon.feather\")\n",
|
| 26 |
+
"df_daegu = pd.read_feather(\"../../data/data_for_modeling/df_daegu.feather\")\n",
|
| 27 |
+
"df_daejeon = pd.read_feather(\"../../data/data_for_modeling/df_daejeon.feather\")\n",
|
| 28 |
+
"df_gwangju = pd.read_feather(\"../../data/data_for_modeling/df_gwangju.feather\")"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"execution_count": 3,
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [
|
| 36 |
+
{
|
| 37 |
+
"data": {
|
| 38 |
+
"text/plain": [
|
| 39 |
+
"Counter({2: 48534, 1: 3941, 0: 109})"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
"execution_count": 3,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"output_type": "execute_result"
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"source": [
|
| 48 |
+
"Counter(df_seoul['multi_class'])"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": 4,
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"outputs": [
|
| 56 |
+
{
|
| 57 |
+
"data": {
|
| 58 |
+
"text/plain": [
|
| 59 |
+
"Counter({2: 50069, 1: 2350, 0: 165})"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
"execution_count": 4,
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"output_type": "execute_result"
|
| 65 |
+
}
|
| 66 |
+
],
|
| 67 |
+
"source": [
|
| 68 |
+
"Counter(df_busan['multi_class'])"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"cell_type": "code",
|
| 73 |
+
"execution_count": 5,
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [
|
| 76 |
+
{
|
| 77 |
+
"data": {
|
| 78 |
+
"text/plain": [
|
| 79 |
+
"Counter({2: 44944, 1: 6658, 0: 982})"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
"execution_count": 5,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"output_type": "execute_result"
|
| 85 |
+
}
|
| 86 |
+
],
|
| 87 |
+
"source": [
|
| 88 |
+
"Counter(df_incheon['multi_class'])"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": 6,
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [
|
| 96 |
+
{
|
| 97 |
+
"data": {
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"Counter({2: 50919, 1: 1610, 0: 55})"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"execution_count": 6,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"output_type": "execute_result"
|
| 105 |
+
}
|
| 106 |
+
],
|
| 107 |
+
"source": [
|
| 108 |
+
"Counter(df_daegu['multi_class'])"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"cell_type": "code",
|
| 113 |
+
"execution_count": 7,
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"outputs": [
|
| 116 |
+
{
|
| 117 |
+
"data": {
|
| 118 |
+
"text/plain": [
|
| 119 |
+
"Counter({2: 48047, 1: 4227, 0: 310})"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
"execution_count": 7,
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"output_type": "execute_result"
|
| 125 |
+
}
|
| 126 |
+
],
|
| 127 |
+
"source": [
|
| 128 |
+
"Counter(df_daejeon['multi_class'])"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"execution_count": 8,
|
| 134 |
+
"metadata": {},
|
| 135 |
+
"outputs": [
|
| 136 |
+
{
|
| 137 |
+
"data": {
|
| 138 |
+
"text/plain": [
|
| 139 |
+
"Counter({2: 48405, 1: 4015, 0: 164})"
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
"execution_count": 8,
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"output_type": "execute_result"
|
| 145 |
+
}
|
| 146 |
+
],
|
| 147 |
+
"source": [
|
| 148 |
+
"Counter(df_gwangju['multi_class'])"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "code",
|
| 153 |
+
"execution_count": 9,
|
| 154 |
+
"metadata": {},
|
| 155 |
+
"outputs": [
|
| 156 |
+
{
|
| 157 |
+
"data": {
|
| 158 |
+
"text/plain": [
|
| 159 |
+
"(52584, 30)"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
"execution_count": 9,
|
| 163 |
+
"metadata": {},
|
| 164 |
+
"output_type": "execute_result"
|
| 165 |
+
}
|
| 166 |
+
],
|
| 167 |
+
"source": [
|
| 168 |
+
"df_seoul.shape"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"cell_type": "code",
|
| 173 |
+
"execution_count": 10,
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"outputs": [],
|
| 176 |
+
"source": [
|
| 177 |
+
"df_seoul = df_seoul.loc[df_seoul['year'].isin([2018, 2019, 2020, 2021]),:].copy()\n",
|
| 178 |
+
"df_busan = df_busan.loc[df_busan['year'].isin([2018, 2019, 2020, 2021]),:].copy()\n",
|
| 179 |
+
"df_incheon = df_incheon.loc[df_incheon['year'].isin([2018, 2019, 2020, 2021]),:].copy()\n",
|
| 180 |
+
"df_daegu = df_daegu.loc[df_daegu['year'].isin([2018, 2019, 2020, 2021]),:].copy()\n",
|
| 181 |
+
"df_daejeon = df_daejeon.loc[df_daejeon['year'].isin([2018, 2019, 2020, 2021]),:].copy()\n",
|
| 182 |
+
"df_gwangju = df_gwangju.loc[df_gwangju['year'].isin([2018, 2019, 2020, 2021]),:].copy()"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": 11,
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [],
|
| 190 |
+
"source": [
|
| 191 |
+
"cols = [col for col in df_seoul.columns if col != \"multi_class\"] + [\"multi_class\"]"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"cell_type": "code",
|
| 196 |
+
"execution_count": 12,
|
| 197 |
+
"metadata": {},
|
| 198 |
+
"outputs": [],
|
| 199 |
+
"source": [
|
| 200 |
+
"df_seoul = df_seoul[cols]\n",
|
| 201 |
+
"df_busan = df_busan[cols]\n",
|
| 202 |
+
"df_incheon = df_incheon[cols]\n",
|
| 203 |
+
"df_daegu = df_daegu[cols]\n",
|
| 204 |
+
"df_daejeon = df_daejeon[cols]\n",
|
| 205 |
+
"df_gwangju = df_gwangju[cols]"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": 13,
|
| 211 |
+
"metadata": {},
|
| 212 |
+
"outputs": [],
|
| 213 |
+
"source": [
|
| 214 |
+
"df_seoul_train = df_seoul.loc[df_seoul['year'].isin([2018, 2019, 2020]),:].copy()\n",
|
| 215 |
+
"df_seoul_test = df_seoul.loc[df_seoul['year'].isin([2021]),:].copy()\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"df_busan_train = df_busan.loc[df_busan['year'].isin([2018, 2019, 2020]),:].copy()\n",
|
| 218 |
+
"df_busan_test = df_busan.loc[df_busan['year'].isin([2021]),:].copy()\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"df_incheon_train = df_incheon.loc[df_incheon['year'].isin([2018, 2019, 2020]),:].copy()\n",
|
| 221 |
+
"df_incheon_test = df_incheon.loc[df_incheon['year'].isin([2021]),:].copy()\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"df_daegu_train = df_daegu.loc[df_daegu['year'].isin([2018, 2019, 2020]),:].copy()\n",
|
| 224 |
+
"df_daegu_test = df_daegu.loc[df_daegu['year'].isin([2021]),:].copy()\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"df_daejeon_train = df_daejeon.loc[df_daejeon['year'].isin([2018, 2019, 2020]),:].copy()\n",
|
| 227 |
+
"df_daejeon_test = df_daejeon.loc[df_daejeon['year'].isin([2021]),:].copy()\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"df_gwangju_train = df_gwangju.loc[df_gwangju['year'].isin([2018, 2019, 2020]),:].copy()\n",
|
| 230 |
+
"df_gwangju_test = df_gwangju.loc[df_gwangju['year'].isin([2021]),:].copy()"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"cell_type": "code",
|
| 235 |
+
"execution_count": 14,
|
| 236 |
+
"metadata": {},
|
| 237 |
+
"outputs": [
|
| 238 |
+
{
|
| 239 |
+
"data": {
|
| 240 |
+
"text/html": [
|
| 241 |
+
"<div>\n",
|
| 242 |
+
"<style scoped>\n",
|
| 243 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 244 |
+
" vertical-align: middle;\n",
|
| 245 |
+
" }\n",
|
| 246 |
+
"\n",
|
| 247 |
+
" .dataframe tbody tr th {\n",
|
| 248 |
+
" vertical-align: top;\n",
|
| 249 |
+
" }\n",
|
| 250 |
+
"\n",
|
| 251 |
+
" .dataframe thead th {\n",
|
| 252 |
+
" text-align: right;\n",
|
| 253 |
+
" }\n",
|
| 254 |
+
"</style>\n",
|
| 255 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 256 |
+
" <thead>\n",
|
| 257 |
+
" <tr style=\"text-align: right;\">\n",
|
| 258 |
+
" <th></th>\n",
|
| 259 |
+
" <th>temp_C</th>\n",
|
| 260 |
+
" <th>precip_mm</th>\n",
|
| 261 |
+
" <th>wind_speed</th>\n",
|
| 262 |
+
" <th>wind_dir</th>\n",
|
| 263 |
+
" <th>hm</th>\n",
|
| 264 |
+
" <th>vap_pressure</th>\n",
|
| 265 |
+
" <th>dewpoint_C</th>\n",
|
| 266 |
+
" <th>loc_pressure</th>\n",
|
| 267 |
+
" <th>sea_pressure</th>\n",
|
| 268 |
+
" <th>solarRad</th>\n",
|
| 269 |
+
" <th>...</th>\n",
|
| 270 |
+
" <th>year</th>\n",
|
| 271 |
+
" <th>month</th>\n",
|
| 272 |
+
" <th>hour</th>\n",
|
| 273 |
+
" <th>ground_temp - temp_C</th>\n",
|
| 274 |
+
" <th>hour_sin</th>\n",
|
| 275 |
+
" <th>hour_cos</th>\n",
|
| 276 |
+
" <th>month_sin</th>\n",
|
| 277 |
+
" <th>month_cos</th>\n",
|
| 278 |
+
" <th>visi</th>\n",
|
| 279 |
+
" <th>multi_class</th>\n",
|
| 280 |
+
" </tr>\n",
|
| 281 |
+
" </thead>\n",
|
| 282 |
+
" <tbody>\n",
|
| 283 |
+
" <tr>\n",
|
| 284 |
+
" <th>0</th>\n",
|
| 285 |
+
" <td>1.2</td>\n",
|
| 286 |
+
" <td>0.0</td>\n",
|
| 287 |
+
" <td>1.6</td>\n",
|
| 288 |
+
" <td>360</td>\n",
|
| 289 |
+
" <td>35.0</td>\n",
|
| 290 |
+
" <td>2.3</td>\n",
|
| 291 |
+
" <td>-12.6</td>\n",
|
| 292 |
+
" <td>1015.8</td>\n",
|
| 293 |
+
" <td>1024.6</td>\n",
|
| 294 |
+
" <td>0.00</td>\n",
|
| 295 |
+
" <td>...</td>\n",
|
| 296 |
+
" <td>2018</td>\n",
|
| 297 |
+
" <td>1</td>\n",
|
| 298 |
+
" <td>0</td>\n",
|
| 299 |
+
" <td>-5.4</td>\n",
|
| 300 |
+
" <td>0.000000</td>\n",
|
| 301 |
+
" <td>1.000000e+00</td>\n",
|
| 302 |
+
" <td>0.5</td>\n",
|
| 303 |
+
" <td>0.866025</td>\n",
|
| 304 |
+
" <td>2000.0</td>\n",
|
| 305 |
+
" <td>2</td>\n",
|
| 306 |
+
" </tr>\n",
|
| 307 |
+
" <tr>\n",
|
| 308 |
+
" <th>1</th>\n",
|
| 309 |
+
" <td>0.5</td>\n",
|
| 310 |
+
" <td>0.0</td>\n",
|
| 311 |
+
" <td>1.3</td>\n",
|
| 312 |
+
" <td>360</td>\n",
|
| 313 |
+
" <td>33.0</td>\n",
|
| 314 |
+
" <td>2.1</td>\n",
|
| 315 |
+
" <td>-13.9</td>\n",
|
| 316 |
+
" <td>1015.5</td>\n",
|
| 317 |
+
" <td>1024.3</td>\n",
|
| 318 |
+
" <td>0.00</td>\n",
|
| 319 |
+
" <td>...</td>\n",
|
| 320 |
+
" <td>2018</td>\n",
|
| 321 |
+
" <td>1</td>\n",
|
| 322 |
+
" <td>1</td>\n",
|
| 323 |
+
" <td>-5.4</td>\n",
|
| 324 |
+
" <td>0.258819</td>\n",
|
| 325 |
+
" <td>9.659258e-01</td>\n",
|
| 326 |
+
" <td>0.5</td>\n",
|
| 327 |
+
" <td>0.866025</td>\n",
|
| 328 |
+
" <td>2000.0</td>\n",
|
| 329 |
+
" <td>2</td>\n",
|
| 330 |
+
" </tr>\n",
|
| 331 |
+
" <tr>\n",
|
| 332 |
+
" <th>2</th>\n",
|
| 333 |
+
" <td>0.1</td>\n",
|
| 334 |
+
" <td>0.0</td>\n",
|
| 335 |
+
" <td>1.5</td>\n",
|
| 336 |
+
" <td>20</td>\n",
|
| 337 |
+
" <td>34.0</td>\n",
|
| 338 |
+
" <td>2.1</td>\n",
|
| 339 |
+
" <td>-13.9</td>\n",
|
| 340 |
+
" <td>1015.7</td>\n",
|
| 341 |
+
" <td>1024.5</td>\n",
|
| 342 |
+
" <td>0.00</td>\n",
|
| 343 |
+
" <td>...</td>\n",
|
| 344 |
+
" <td>2018</td>\n",
|
| 345 |
+
" <td>1</td>\n",
|
| 346 |
+
" <td>2</td>\n",
|
| 347 |
+
" <td>-5.4</td>\n",
|
| 348 |
+
" <td>0.500000</td>\n",
|
| 349 |
+
" <td>8.660254e-01</td>\n",
|
| 350 |
+
" <td>0.5</td>\n",
|
| 351 |
+
" <td>0.866025</td>\n",
|
| 352 |
+
" <td>2000.0</td>\n",
|
| 353 |
+
" <td>2</td>\n",
|
| 354 |
+
" </tr>\n",
|
| 355 |
+
" <tr>\n",
|
| 356 |
+
" <th>3</th>\n",
|
| 357 |
+
" <td>0.0</td>\n",
|
| 358 |
+
" <td>0.0</td>\n",
|
| 359 |
+
" <td>2.1</td>\n",
|
| 360 |
+
" <td>320</td>\n",
|
| 361 |
+
" <td>37.0</td>\n",
|
| 362 |
+
" <td>2.3</td>\n",
|
| 363 |
+
" <td>-12.9</td>\n",
|
| 364 |
+
" <td>1015.9</td>\n",
|
| 365 |
+
" <td>1024.7</td>\n",
|
| 366 |
+
" <td>0.00</td>\n",
|
| 367 |
+
" <td>...</td>\n",
|
| 368 |
+
" <td>2018</td>\n",
|
| 369 |
+
" <td>1</td>\n",
|
| 370 |
+
" <td>3</td>\n",
|
| 371 |
+
" <td>-5.0</td>\n",
|
| 372 |
+
" <td>0.707107</td>\n",
|
| 373 |
+
" <td>7.071068e-01</td>\n",
|
| 374 |
+
" <td>0.5</td>\n",
|
| 375 |
+
" <td>0.866025</td>\n",
|
| 376 |
+
" <td>2000.0</td>\n",
|
| 377 |
+
" <td>2</td>\n",
|
| 378 |
+
" </tr>\n",
|
| 379 |
+
" <tr>\n",
|
| 380 |
+
" <th>4</th>\n",
|
| 381 |
+
" <td>-0.1</td>\n",
|
| 382 |
+
" <td>0.0</td>\n",
|
| 383 |
+
" <td>2.3</td>\n",
|
| 384 |
+
" <td>340</td>\n",
|
| 385 |
+
" <td>42.0</td>\n",
|
| 386 |
+
" <td>2.5</td>\n",
|
| 387 |
+
" <td>-11.5</td>\n",
|
| 388 |
+
" <td>1016.0</td>\n",
|
| 389 |
+
" <td>1024.9</td>\n",
|
| 390 |
+
" <td>0.00</td>\n",
|
| 391 |
+
" <td>...</td>\n",
|
| 392 |
+
" <td>2018</td>\n",
|
| 393 |
+
" <td>1</td>\n",
|
| 394 |
+
" <td>4</td>\n",
|
| 395 |
+
" <td>-4.3</td>\n",
|
| 396 |
+
" <td>0.866025</td>\n",
|
| 397 |
+
" <td>5.000000e-01</td>\n",
|
| 398 |
+
" <td>0.5</td>\n",
|
| 399 |
+
" <td>0.866025</td>\n",
|
| 400 |
+
" <td>2000.0</td>\n",
|
| 401 |
+
" <td>2</td>\n",
|
| 402 |
+
" </tr>\n",
|
| 403 |
+
" <tr>\n",
|
| 404 |
+
" <th>5</th>\n",
|
| 405 |
+
" <td>-0.1</td>\n",
|
| 406 |
+
" <td>0.0</td>\n",
|
| 407 |
+
" <td>2.8</td>\n",
|
| 408 |
+
" <td>50</td>\n",
|
| 409 |
+
" <td>43.0</td>\n",
|
| 410 |
+
" <td>2.6</td>\n",
|
| 411 |
+
" <td>-11.2</td>\n",
|
| 412 |
+
" <td>1016.0</td>\n",
|
| 413 |
+
" <td>1024.9</td>\n",
|
| 414 |
+
" <td>0.00</td>\n",
|
| 415 |
+
" <td>...</td>\n",
|
| 416 |
+
" <td>2018</td>\n",
|
| 417 |
+
" <td>1</td>\n",
|
| 418 |
+
" <td>5</td>\n",
|
| 419 |
+
" <td>-4.0</td>\n",
|
| 420 |
+
" <td>0.965926</td>\n",
|
| 421 |
+
" <td>2.588190e-01</td>\n",
|
| 422 |
+
" <td>0.5</td>\n",
|
| 423 |
+
" <td>0.866025</td>\n",
|
| 424 |
+
" <td>2000.0</td>\n",
|
| 425 |
+
" <td>2</td>\n",
|
| 426 |
+
" </tr>\n",
|
| 427 |
+
" <tr>\n",
|
| 428 |
+
" <th>6</th>\n",
|
| 429 |
+
" <td>-0.5</td>\n",
|
| 430 |
+
" <td>0.0</td>\n",
|
| 431 |
+
" <td>2.1</td>\n",
|
| 432 |
+
" <td>20</td>\n",
|
| 433 |
+
" <td>45.0</td>\n",
|
| 434 |
+
" <td>2.6</td>\n",
|
| 435 |
+
" <td>-11.0</td>\n",
|
| 436 |
+
" <td>1016.5</td>\n",
|
| 437 |
+
" <td>1025.4</td>\n",
|
| 438 |
+
" <td>0.00</td>\n",
|
| 439 |
+
" <td>...</td>\n",
|
| 440 |
+
" <td>2018</td>\n",
|
| 441 |
+
" <td>1</td>\n",
|
| 442 |
+
" <td>6</td>\n",
|
| 443 |
+
" <td>-4.1</td>\n",
|
| 444 |
+
" <td>1.000000</td>\n",
|
| 445 |
+
" <td>6.123234e-17</td>\n",
|
| 446 |
+
" <td>0.5</td>\n",
|
| 447 |
+
" <td>0.866025</td>\n",
|
| 448 |
+
" <td>2000.0</td>\n",
|
| 449 |
+
" <td>2</td>\n",
|
| 450 |
+
" </tr>\n",
|
| 451 |
+
" <tr>\n",
|
| 452 |
+
" <th>7</th>\n",
|
| 453 |
+
" <td>-0.8</td>\n",
|
| 454 |
+
" <td>0.0</td>\n",
|
| 455 |
+
" <td>2.5</td>\n",
|
| 456 |
+
" <td>340</td>\n",
|
| 457 |
+
" <td>45.0</td>\n",
|
| 458 |
+
" <td>2.6</td>\n",
|
| 459 |
+
" <td>-11.2</td>\n",
|
| 460 |
+
" <td>1017.1</td>\n",
|
| 461 |
+
" <td>1026.0</td>\n",
|
| 462 |
+
" <td>0.00</td>\n",
|
| 463 |
+
" <td>...</td>\n",
|
| 464 |
+
" <td>2018</td>\n",
|
| 465 |
+
" <td>1</td>\n",
|
| 466 |
+
" <td>7</td>\n",
|
| 467 |
+
" <td>-4.5</td>\n",
|
| 468 |
+
" <td>0.965926</td>\n",
|
| 469 |
+
" <td>-2.588190e-01</td>\n",
|
| 470 |
+
" <td>0.5</td>\n",
|
| 471 |
+
" <td>0.866025</td>\n",
|
| 472 |
+
" <td>2000.0</td>\n",
|
| 473 |
+
" <td>2</td>\n",
|
| 474 |
+
" </tr>\n",
|
| 475 |
+
" <tr>\n",
|
| 476 |
+
" <th>8</th>\n",
|
| 477 |
+
" <td>-0.5</td>\n",
|
| 478 |
+
" <td>0.0</td>\n",
|
| 479 |
+
" <td>1.2</td>\n",
|
| 480 |
+
" <td>360</td>\n",
|
| 481 |
+
" <td>43.0</td>\n",
|
| 482 |
+
" <td>2.5</td>\n",
|
| 483 |
+
" <td>-11.5</td>\n",
|
| 484 |
+
" <td>1017.4</td>\n",
|
| 485 |
+
" <td>1026.3</td>\n",
|
| 486 |
+
" <td>0.03</td>\n",
|
| 487 |
+
" <td>...</td>\n",
|
| 488 |
+
" <td>2018</td>\n",
|
| 489 |
+
" <td>1</td>\n",
|
| 490 |
+
" <td>8</td>\n",
|
| 491 |
+
" <td>-4.0</td>\n",
|
| 492 |
+
" <td>0.866025</td>\n",
|
| 493 |
+
" <td>-5.000000e-01</td>\n",
|
| 494 |
+
" <td>0.5</td>\n",
|
| 495 |
+
" <td>0.866025</td>\n",
|
| 496 |
+
" <td>2000.0</td>\n",
|
| 497 |
+
" <td>2</td>\n",
|
| 498 |
+
" </tr>\n",
|
| 499 |
+
" <tr>\n",
|
| 500 |
+
" <th>9</th>\n",
|
| 501 |
+
" <td>1.7</td>\n",
|
| 502 |
+
" <td>0.0</td>\n",
|
| 503 |
+
" <td>2.1</td>\n",
|
| 504 |
+
" <td>20</td>\n",
|
| 505 |
+
" <td>39.0</td>\n",
|
| 506 |
+
" <td>2.7</td>\n",
|
| 507 |
+
" <td>-10.8</td>\n",
|
| 508 |
+
" <td>1018.1</td>\n",
|
| 509 |
+
" <td>1026.9</td>\n",
|
| 510 |
+
" <td>0.46</td>\n",
|
| 511 |
+
" <td>...</td>\n",
|
| 512 |
+
" <td>2018</td>\n",
|
| 513 |
+
" <td>1</td>\n",
|
| 514 |
+
" <td>9</td>\n",
|
| 515 |
+
" <td>2.8</td>\n",
|
| 516 |
+
" <td>0.707107</td>\n",
|
| 517 |
+
" <td>-7.071068e-01</td>\n",
|
| 518 |
+
" <td>0.5</td>\n",
|
| 519 |
+
" <td>0.866025</td>\n",
|
| 520 |
+
" <td>1953.0</td>\n",
|
| 521 |
+
" <td>2</td>\n",
|
| 522 |
+
" </tr>\n",
|
| 523 |
+
" </tbody>\n",
|
| 524 |
+
"</table>\n",
|
| 525 |
+
"<p>10 rows ร 30 columns</p>\n",
|
| 526 |
+
"</div>"
|
| 527 |
+
],
|
| 528 |
+
"text/plain": [
|
| 529 |
+
" temp_C precip_mm wind_speed wind_dir hm vap_pressure dewpoint_C \\\n",
|
| 530 |
+
"0 1.2 0.0 1.6 360 35.0 2.3 -12.6 \n",
|
| 531 |
+
"1 0.5 0.0 1.3 360 33.0 2.1 -13.9 \n",
|
| 532 |
+
"2 0.1 0.0 1.5 20 34.0 2.1 -13.9 \n",
|
| 533 |
+
"3 0.0 0.0 2.1 320 37.0 2.3 -12.9 \n",
|
| 534 |
+
"4 -0.1 0.0 2.3 340 42.0 2.5 -11.5 \n",
|
| 535 |
+
"5 -0.1 0.0 2.8 50 43.0 2.6 -11.2 \n",
|
| 536 |
+
"6 -0.5 0.0 2.1 20 45.0 2.6 -11.0 \n",
|
| 537 |
+
"7 -0.8 0.0 2.5 340 45.0 2.6 -11.2 \n",
|
| 538 |
+
"8 -0.5 0.0 1.2 360 43.0 2.5 -11.5 \n",
|
| 539 |
+
"9 1.7 0.0 2.1 20 39.0 2.7 -10.8 \n",
|
| 540 |
+
"\n",
|
| 541 |
+
" loc_pressure sea_pressure solarRad ... year month hour \\\n",
|
| 542 |
+
"0 1015.8 1024.6 0.00 ... 2018 1 0 \n",
|
| 543 |
+
"1 1015.5 1024.3 0.00 ... 2018 1 1 \n",
|
| 544 |
+
"2 1015.7 1024.5 0.00 ... 2018 1 2 \n",
|
| 545 |
+
"3 1015.9 1024.7 0.00 ... 2018 1 3 \n",
|
| 546 |
+
"4 1016.0 1024.9 0.00 ... 2018 1 4 \n",
|
| 547 |
+
"5 1016.0 1024.9 0.00 ... 2018 1 5 \n",
|
| 548 |
+
"6 1016.5 1025.4 0.00 ... 2018 1 6 \n",
|
| 549 |
+
"7 1017.1 1026.0 0.00 ... 2018 1 7 \n",
|
| 550 |
+
"8 1017.4 1026.3 0.03 ... 2018 1 8 \n",
|
| 551 |
+
"9 1018.1 1026.9 0.46 ... 2018 1 9 \n",
|
| 552 |
+
"\n",
|
| 553 |
+
" ground_temp - temp_C hour_sin hour_cos month_sin month_cos visi \\\n",
|
| 554 |
+
"0 -5.4 0.000000 1.000000e+00 0.5 0.866025 2000.0 \n",
|
| 555 |
+
"1 -5.4 0.258819 9.659258e-01 0.5 0.866025 2000.0 \n",
|
| 556 |
+
"2 -5.4 0.500000 8.660254e-01 0.5 0.866025 2000.0 \n",
|
| 557 |
+
"3 -5.0 0.707107 7.071068e-01 0.5 0.866025 2000.0 \n",
|
| 558 |
+
"4 -4.3 0.866025 5.000000e-01 0.5 0.866025 2000.0 \n",
|
| 559 |
+
"5 -4.0 0.965926 2.588190e-01 0.5 0.866025 2000.0 \n",
|
| 560 |
+
"6 -4.1 1.000000 6.123234e-17 0.5 0.866025 2000.0 \n",
|
| 561 |
+
"7 -4.5 0.965926 -2.588190e-01 0.5 0.866025 2000.0 \n",
|
| 562 |
+
"8 -4.0 0.866025 -5.000000e-01 0.5 0.866025 2000.0 \n",
|
| 563 |
+
"9 2.8 0.707107 -7.071068e-01 0.5 0.866025 1953.0 \n",
|
| 564 |
+
"\n",
|
| 565 |
+
" multi_class \n",
|
| 566 |
+
"0 2 \n",
|
| 567 |
+
"1 2 \n",
|
| 568 |
+
"2 2 \n",
|
| 569 |
+
"3 2 \n",
|
| 570 |
+
"4 2 \n",
|
| 571 |
+
"5 2 \n",
|
| 572 |
+
"6 2 \n",
|
| 573 |
+
"7 2 \n",
|
| 574 |
+
"8 2 \n",
|
| 575 |
+
"9 2 \n",
|
| 576 |
+
"\n",
|
| 577 |
+
"[10 rows x 30 columns]"
|
| 578 |
+
]
|
| 579 |
+
},
|
| 580 |
+
"execution_count": 14,
|
| 581 |
+
"metadata": {},
|
| 582 |
+
"output_type": "execute_result"
|
| 583 |
+
}
|
| 584 |
+
],
|
| 585 |
+
"source": [
|
| 586 |
+
"df_busan_train.head(10)"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
{
|
| 590 |
+
"cell_type": "code",
|
| 591 |
+
"execution_count": 15,
|
| 592 |
+
"metadata": {},
|
| 593 |
+
"outputs": [
|
| 594 |
+
{
|
| 595 |
+
"data": {
|
| 596 |
+
"text/html": [
|
| 597 |
+
"<div>\n",
|
| 598 |
+
"<style scoped>\n",
|
| 599 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 600 |
+
" vertical-align: middle;\n",
|
| 601 |
+
" }\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" .dataframe tbody tr th {\n",
|
| 604 |
+
" vertical-align: top;\n",
|
| 605 |
+
" }\n",
|
| 606 |
+
"\n",
|
| 607 |
+
" .dataframe thead th {\n",
|
| 608 |
+
" text-align: right;\n",
|
| 609 |
+
" }\n",
|
| 610 |
+
"</style>\n",
|
| 611 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 612 |
+
" <thead>\n",
|
| 613 |
+
" <tr style=\"text-align: right;\">\n",
|
| 614 |
+
" <th></th>\n",
|
| 615 |
+
" <th>temp_C</th>\n",
|
| 616 |
+
" <th>precip_mm</th>\n",
|
| 617 |
+
" <th>wind_speed</th>\n",
|
| 618 |
+
" <th>wind_dir</th>\n",
|
| 619 |
+
" <th>hm</th>\n",
|
| 620 |
+
" <th>vap_pressure</th>\n",
|
| 621 |
+
" <th>dewpoint_C</th>\n",
|
| 622 |
+
" <th>loc_pressure</th>\n",
|
| 623 |
+
" <th>sea_pressure</th>\n",
|
| 624 |
+
" <th>solarRad</th>\n",
|
| 625 |
+
" <th>...</th>\n",
|
| 626 |
+
" <th>year</th>\n",
|
| 627 |
+
" <th>month</th>\n",
|
| 628 |
+
" <th>hour</th>\n",
|
| 629 |
+
" <th>ground_temp - temp_C</th>\n",
|
| 630 |
+
" <th>hour_sin</th>\n",
|
| 631 |
+
" <th>hour_cos</th>\n",
|
| 632 |
+
" <th>month_sin</th>\n",
|
| 633 |
+
" <th>month_cos</th>\n",
|
| 634 |
+
" <th>visi</th>\n",
|
| 635 |
+
" <th>multi_class</th>\n",
|
| 636 |
+
" </tr>\n",
|
| 637 |
+
" </thead>\n",
|
| 638 |
+
" <tbody>\n",
|
| 639 |
+
" <tr>\n",
|
| 640 |
+
" <th>26294</th>\n",
|
| 641 |
+
" <td>0.1</td>\n",
|
| 642 |
+
" <td>0.0</td>\n",
|
| 643 |
+
" <td>6.3</td>\n",
|
| 644 |
+
" <td>270</td>\n",
|
| 645 |
+
" <td>37.0</td>\n",
|
| 646 |
+
" <td>2.3</td>\n",
|
| 647 |
+
" <td>-12.9</td>\n",
|
| 648 |
+
" <td>1013.3</td>\n",
|
| 649 |
+
" <td>1022.1</td>\n",
|
| 650 |
+
" <td>2.07</td>\n",
|
| 651 |
+
" <td>...</td>\n",
|
| 652 |
+
" <td>2020</td>\n",
|
| 653 |
+
" <td>12</td>\n",
|
| 654 |
+
" <td>14</td>\n",
|
| 655 |
+
" <td>5.8</td>\n",
|
| 656 |
+
" <td>-0.500000</td>\n",
|
| 657 |
+
" <td>-8.660254e-01</td>\n",
|
| 658 |
+
" <td>-2.449294e-16</td>\n",
|
| 659 |
+
" <td>1.0</td>\n",
|
| 660 |
+
" <td>5000.0</td>\n",
|
| 661 |
+
" <td>2</td>\n",
|
| 662 |
+
" </tr>\n",
|
| 663 |
+
" <tr>\n",
|
| 664 |
+
" <th>26295</th>\n",
|
| 665 |
+
" <td>1.2</td>\n",
|
| 666 |
+
" <td>0.0</td>\n",
|
| 667 |
+
" <td>5.9</td>\n",
|
| 668 |
+
" <td>270</td>\n",
|
| 669 |
+
" <td>35.0</td>\n",
|
| 670 |
+
" <td>2.3</td>\n",
|
| 671 |
+
" <td>-12.6</td>\n",
|
| 672 |
+
" <td>1013.2</td>\n",
|
| 673 |
+
" <td>1022.0</td>\n",
|
| 674 |
+
" <td>1.71</td>\n",
|
| 675 |
+
" <td>...</td>\n",
|
| 676 |
+
" <td>2020</td>\n",
|
| 677 |
+
" <td>12</td>\n",
|
| 678 |
+
" <td>15</td>\n",
|
| 679 |
+
" <td>5.6</td>\n",
|
| 680 |
+
" <td>-0.707107</td>\n",
|
| 681 |
+
" <td>-7.071068e-01</td>\n",
|
| 682 |
+
" <td>-2.449294e-16</td>\n",
|
| 683 |
+
" <td>1.0</td>\n",
|
| 684 |
+
" <td>5000.0</td>\n",
|
| 685 |
+
" <td>2</td>\n",
|
| 686 |
+
" </tr>\n",
|
| 687 |
+
" <tr>\n",
|
| 688 |
+
" <th>26296</th>\n",
|
| 689 |
+
" <td>1.6</td>\n",
|
| 690 |
+
" <td>0.0</td>\n",
|
| 691 |
+
" <td>3.6</td>\n",
|
| 692 |
+
" <td>290</td>\n",
|
| 693 |
+
" <td>34.0</td>\n",
|
| 694 |
+
" <td>2.3</td>\n",
|
| 695 |
+
" <td>-12.6</td>\n",
|
| 696 |
+
" <td>1012.8</td>\n",
|
| 697 |
+
" <td>1021.6</td>\n",
|
| 698 |
+
" <td>1.14</td>\n",
|
| 699 |
+
" <td>...</td>\n",
|
| 700 |
+
" <td>2020</td>\n",
|
| 701 |
+
" <td>12</td>\n",
|
| 702 |
+
" <td>16</td>\n",
|
| 703 |
+
" <td>1.4</td>\n",
|
| 704 |
+
" <td>-0.866025</td>\n",
|
| 705 |
+
" <td>-5.000000e-01</td>\n",
|
| 706 |
+
" <td>-2.449294e-16</td>\n",
|
| 707 |
+
" <td>1.0</td>\n",
|
| 708 |
+
" <td>5000.0</td>\n",
|
| 709 |
+
" <td>2</td>\n",
|
| 710 |
+
" </tr>\n",
|
| 711 |
+
" <tr>\n",
|
| 712 |
+
" <th>26297</th>\n",
|
| 713 |
+
" <td>1.2</td>\n",
|
| 714 |
+
" <td>0.0</td>\n",
|
| 715 |
+
" <td>3.8</td>\n",
|
| 716 |
+
" <td>250</td>\n",
|
| 717 |
+
" <td>38.0</td>\n",
|
| 718 |
+
" <td>2.5</td>\n",
|
| 719 |
+
" <td>-11.5</td>\n",
|
| 720 |
+
" <td>1012.8</td>\n",
|
| 721 |
+
" <td>1021.6</td>\n",
|
| 722 |
+
" <td>0.48</td>\n",
|
| 723 |
+
" <td>...</td>\n",
|
| 724 |
+
" <td>2020</td>\n",
|
| 725 |
+
" <td>12</td>\n",
|
| 726 |
+
" <td>17</td>\n",
|
| 727 |
+
" <td>-0.4</td>\n",
|
| 728 |
+
" <td>-0.965926</td>\n",
|
| 729 |
+
" <td>-2.588190e-01</td>\n",
|
| 730 |
+
" <td>-2.449294e-16</td>\n",
|
| 731 |
+
" <td>1.0</td>\n",
|
| 732 |
+
" <td>5000.0</td>\n",
|
| 733 |
+
" <td>2</td>\n",
|
| 734 |
+
" </tr>\n",
|
| 735 |
+
" <tr>\n",
|
| 736 |
+
" <th>26298</th>\n",
|
| 737 |
+
" <td>0.9</td>\n",
|
| 738 |
+
" <td>0.0</td>\n",
|
| 739 |
+
" <td>3.8</td>\n",
|
| 740 |
+
" <td>270</td>\n",
|
| 741 |
+
" <td>40.0</td>\n",
|
| 742 |
+
" <td>2.6</td>\n",
|
| 743 |
+
" <td>-11.2</td>\n",
|
| 744 |
+
" <td>1013.1</td>\n",
|
| 745 |
+
" <td>1021.9</td>\n",
|
| 746 |
+
" <td>0.02</td>\n",
|
| 747 |
+
" <td>...</td>\n",
|
| 748 |
+
" <td>2020</td>\n",
|
| 749 |
+
" <td>12</td>\n",
|
| 750 |
+
" <td>18</td>\n",
|
| 751 |
+
" <td>-0.8</td>\n",
|
| 752 |
+
" <td>-1.000000</td>\n",
|
| 753 |
+
" <td>-1.836970e-16</td>\n",
|
| 754 |
+
" <td>-2.449294e-16</td>\n",
|
| 755 |
+
" <td>1.0</td>\n",
|
| 756 |
+
" <td>5000.0</td>\n",
|
| 757 |
+
" <td>2</td>\n",
|
| 758 |
+
" </tr>\n",
|
| 759 |
+
" <tr>\n",
|
| 760 |
+
" <th>26299</th>\n",
|
| 761 |
+
" <td>0.6</td>\n",
|
| 762 |
+
" <td>0.0</td>\n",
|
| 763 |
+
" <td>6.2</td>\n",
|
| 764 |
+
" <td>270</td>\n",
|
| 765 |
+
" <td>41.0</td>\n",
|
| 766 |
+
" <td>2.6</td>\n",
|
| 767 |
+
" <td>-11.1</td>\n",
|
| 768 |
+
" <td>1014.0</td>\n",
|
| 769 |
+
" <td>1022.8</td>\n",
|
| 770 |
+
" <td>0.00</td>\n",
|
| 771 |
+
" <td>...</td>\n",
|
| 772 |
+
" <td>2020</td>\n",
|
| 773 |
+
" <td>12</td>\n",
|
| 774 |
+
" <td>19</td>\n",
|
| 775 |
+
" <td>-1.1</td>\n",
|
| 776 |
+
" <td>-0.965926</td>\n",
|
| 777 |
+
" <td>2.588190e-01</td>\n",
|
| 778 |
+
" <td>-2.449294e-16</td>\n",
|
| 779 |
+
" <td>1.0</td>\n",
|
| 780 |
+
" <td>5000.0</td>\n",
|
| 781 |
+
" <td>2</td>\n",
|
| 782 |
+
" </tr>\n",
|
| 783 |
+
" <tr>\n",
|
| 784 |
+
" <th>26300</th>\n",
|
| 785 |
+
" <td>0.1</td>\n",
|
| 786 |
+
" <td>0.0</td>\n",
|
| 787 |
+
" <td>6.0</td>\n",
|
| 788 |
+
" <td>270</td>\n",
|
| 789 |
+
" <td>44.0</td>\n",
|
| 790 |
+
" <td>2.7</td>\n",
|
| 791 |
+
" <td>-10.7</td>\n",
|
| 792 |
+
" <td>1014.8</td>\n",
|
| 793 |
+
" <td>1023.6</td>\n",
|
| 794 |
+
" <td>0.00</td>\n",
|
| 795 |
+
" <td>...</td>\n",
|
| 796 |
+
" <td>2020</td>\n",
|
| 797 |
+
" <td>12</td>\n",
|
| 798 |
+
" <td>20</td>\n",
|
| 799 |
+
" <td>-0.9</td>\n",
|
| 800 |
+
" <td>-0.866025</td>\n",
|
| 801 |
+
" <td>5.000000e-01</td>\n",
|
| 802 |
+
" <td>-2.449294e-16</td>\n",
|
| 803 |
+
" <td>1.0</td>\n",
|
| 804 |
+
" <td>5000.0</td>\n",
|
| 805 |
+
" <td>2</td>\n",
|
| 806 |
+
" </tr>\n",
|
| 807 |
+
" <tr>\n",
|
| 808 |
+
" <th>26301</th>\n",
|
| 809 |
+
" <td>-0.2</td>\n",
|
| 810 |
+
" <td>0.0</td>\n",
|
| 811 |
+
" <td>5.0</td>\n",
|
| 812 |
+
" <td>290</td>\n",
|
| 813 |
+
" <td>48.0</td>\n",
|
| 814 |
+
" <td>2.9</td>\n",
|
| 815 |
+
" <td>-9.9</td>\n",
|
| 816 |
+
" <td>1014.6</td>\n",
|
| 817 |
+
" <td>1023.4</td>\n",
|
| 818 |
+
" <td>0.00</td>\n",
|
| 819 |
+
" <td>...</td>\n",
|
| 820 |
+
" <td>2020</td>\n",
|
| 821 |
+
" <td>12</td>\n",
|
| 822 |
+
" <td>21</td>\n",
|
| 823 |
+
" <td>-0.8</td>\n",
|
| 824 |
+
" <td>-0.707107</td>\n",
|
| 825 |
+
" <td>7.071068e-01</td>\n",
|
| 826 |
+
" <td>-2.449294e-16</td>\n",
|
| 827 |
+
" <td>1.0</td>\n",
|
| 828 |
+
" <td>5000.0</td>\n",
|
| 829 |
+
" <td>2</td>\n",
|
| 830 |
+
" </tr>\n",
|
| 831 |
+
" <tr>\n",
|
| 832 |
+
" <th>26302</th>\n",
|
| 833 |
+
" <td>-0.7</td>\n",
|
| 834 |
+
" <td>0.0</td>\n",
|
| 835 |
+
" <td>2.7</td>\n",
|
| 836 |
+
" <td>270</td>\n",
|
| 837 |
+
" <td>51.0</td>\n",
|
| 838 |
+
" <td>3.0</td>\n",
|
| 839 |
+
" <td>-9.6</td>\n",
|
| 840 |
+
" <td>1014.8</td>\n",
|
| 841 |
+
" <td>1023.6</td>\n",
|
| 842 |
+
" <td>0.00</td>\n",
|
| 843 |
+
" <td>...</td>\n",
|
| 844 |
+
" <td>2020</td>\n",
|
| 845 |
+
" <td>12</td>\n",
|
| 846 |
+
" <td>22</td>\n",
|
| 847 |
+
" <td>-0.6</td>\n",
|
| 848 |
+
" <td>-0.500000</td>\n",
|
| 849 |
+
" <td>8.660254e-01</td>\n",
|
| 850 |
+
" <td>-2.449294e-16</td>\n",
|
| 851 |
+
" <td>1.0</td>\n",
|
| 852 |
+
" <td>5000.0</td>\n",
|
| 853 |
+
" <td>2</td>\n",
|
| 854 |
+
" </tr>\n",
|
| 855 |
+
" <tr>\n",
|
| 856 |
+
" <th>26303</th>\n",
|
| 857 |
+
" <td>-0.7</td>\n",
|
| 858 |
+
" <td>0.0</td>\n",
|
| 859 |
+
" <td>3.8</td>\n",
|
| 860 |
+
" <td>250</td>\n",
|
| 861 |
+
" <td>55.0</td>\n",
|
| 862 |
+
" <td>3.2</td>\n",
|
| 863 |
+
" <td>-8.6</td>\n",
|
| 864 |
+
" <td>1015.1</td>\n",
|
| 865 |
+
" <td>1024.0</td>\n",
|
| 866 |
+
" <td>0.00</td>\n",
|
| 867 |
+
" <td>...</td>\n",
|
| 868 |
+
" <td>2020</td>\n",
|
| 869 |
+
" <td>12</td>\n",
|
| 870 |
+
" <td>23</td>\n",
|
| 871 |
+
" <td>-0.6</td>\n",
|
| 872 |
+
" <td>-0.258819</td>\n",
|
| 873 |
+
" <td>9.659258e-01</td>\n",
|
| 874 |
+
" <td>-2.449294e-16</td>\n",
|
| 875 |
+
" <td>1.0</td>\n",
|
| 876 |
+
" <td>5000.0</td>\n",
|
| 877 |
+
" <td>2</td>\n",
|
| 878 |
+
" </tr>\n",
|
| 879 |
+
" </tbody>\n",
|
| 880 |
+
"</table>\n",
|
| 881 |
+
"<p>10 rows ร 30 columns</p>\n",
|
| 882 |
+
"</div>"
|
| 883 |
+
],
|
| 884 |
+
"text/plain": [
|
| 885 |
+
" temp_C precip_mm wind_speed wind_dir hm vap_pressure dewpoint_C \\\n",
|
| 886 |
+
"26294 0.1 0.0 6.3 270 37.0 2.3 -12.9 \n",
|
| 887 |
+
"26295 1.2 0.0 5.9 270 35.0 2.3 -12.6 \n",
|
| 888 |
+
"26296 1.6 0.0 3.6 290 34.0 2.3 -12.6 \n",
|
| 889 |
+
"26297 1.2 0.0 3.8 250 38.0 2.5 -11.5 \n",
|
| 890 |
+
"26298 0.9 0.0 3.8 270 40.0 2.6 -11.2 \n",
|
| 891 |
+
"26299 0.6 0.0 6.2 270 41.0 2.6 -11.1 \n",
|
| 892 |
+
"26300 0.1 0.0 6.0 270 44.0 2.7 -10.7 \n",
|
| 893 |
+
"26301 -0.2 0.0 5.0 290 48.0 2.9 -9.9 \n",
|
| 894 |
+
"26302 -0.7 0.0 2.7 270 51.0 3.0 -9.6 \n",
|
| 895 |
+
"26303 -0.7 0.0 3.8 250 55.0 3.2 -8.6 \n",
|
| 896 |
+
"\n",
|
| 897 |
+
" loc_pressure sea_pressure solarRad ... year month hour \\\n",
|
| 898 |
+
"26294 1013.3 1022.1 2.07 ... 2020 12 14 \n",
|
| 899 |
+
"26295 1013.2 1022.0 1.71 ... 2020 12 15 \n",
|
| 900 |
+
"26296 1012.8 1021.6 1.14 ... 2020 12 16 \n",
|
| 901 |
+
"26297 1012.8 1021.6 0.48 ... 2020 12 17 \n",
|
| 902 |
+
"26298 1013.1 1021.9 0.02 ... 2020 12 18 \n",
|
| 903 |
+
"26299 1014.0 1022.8 0.00 ... 2020 12 19 \n",
|
| 904 |
+
"26300 1014.8 1023.6 0.00 ... 2020 12 20 \n",
|
| 905 |
+
"26301 1014.6 1023.4 0.00 ... 2020 12 21 \n",
|
| 906 |
+
"26302 1014.8 1023.6 0.00 ... 2020 12 22 \n",
|
| 907 |
+
"26303 1015.1 1024.0 0.00 ... 2020 12 23 \n",
|
| 908 |
+
"\n",
|
| 909 |
+
" ground_temp - temp_C hour_sin hour_cos month_sin month_cos \\\n",
|
| 910 |
+
"26294 5.8 -0.500000 -8.660254e-01 -2.449294e-16 1.0 \n",
|
| 911 |
+
"26295 5.6 -0.707107 -7.071068e-01 -2.449294e-16 1.0 \n",
|
| 912 |
+
"26296 1.4 -0.866025 -5.000000e-01 -2.449294e-16 1.0 \n",
|
| 913 |
+
"26297 -0.4 -0.965926 -2.588190e-01 -2.449294e-16 1.0 \n",
|
| 914 |
+
"26298 -0.8 -1.000000 -1.836970e-16 -2.449294e-16 1.0 \n",
|
| 915 |
+
"26299 -1.1 -0.965926 2.588190e-01 -2.449294e-16 1.0 \n",
|
| 916 |
+
"26300 -0.9 -0.866025 5.000000e-01 -2.449294e-16 1.0 \n",
|
| 917 |
+
"26301 -0.8 -0.707107 7.071068e-01 -2.449294e-16 1.0 \n",
|
| 918 |
+
"26302 -0.6 -0.500000 8.660254e-01 -2.449294e-16 1.0 \n",
|
| 919 |
+
"26303 -0.6 -0.258819 9.659258e-01 -2.449294e-16 1.0 \n",
|
| 920 |
+
"\n",
|
| 921 |
+
" visi multi_class \n",
|
| 922 |
+
"26294 5000.0 2 \n",
|
| 923 |
+
"26295 5000.0 2 \n",
|
| 924 |
+
"26296 5000.0 2 \n",
|
| 925 |
+
"26297 5000.0 2 \n",
|
| 926 |
+
"26298 5000.0 2 \n",
|
| 927 |
+
"26299 5000.0 2 \n",
|
| 928 |
+
"26300 5000.0 2 \n",
|
| 929 |
+
"26301 5000.0 2 \n",
|
| 930 |
+
"26302 5000.0 2 \n",
|
| 931 |
+
"26303 5000.0 2 \n",
|
| 932 |
+
"\n",
|
| 933 |
+
"[10 rows x 30 columns]"
|
| 934 |
+
]
|
| 935 |
+
},
|
| 936 |
+
"execution_count": 15,
|
| 937 |
+
"metadata": {},
|
| 938 |
+
"output_type": "execute_result"
|
| 939 |
+
}
|
| 940 |
+
],
|
| 941 |
+
"source": [
|
| 942 |
+
"df_busan_train.tail(10)"
|
| 943 |
+
]
|
| 944 |
+
},
|
| 945 |
+
{
|
| 946 |
+
"cell_type": "code",
|
| 947 |
+
"execution_count": 16,
|
| 948 |
+
"metadata": {},
|
| 949 |
+
"outputs": [
|
| 950 |
+
{
|
| 951 |
+
"name": "stdout",
|
| 952 |
+
"output_type": "stream",
|
| 953 |
+
"text": [
|
| 954 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
| 955 |
+
"Index: 26304 entries, 0 to 26303\n",
|
| 956 |
+
"Data columns (total 30 columns):\n",
|
| 957 |
+
" # Column Non-Null Count Dtype \n",
|
| 958 |
+
"--- ------ -------------- ----- \n",
|
| 959 |
+
" 0 temp_C 26304 non-null float64 \n",
|
| 960 |
+
" 1 precip_mm 26304 non-null float64 \n",
|
| 961 |
+
" 2 wind_speed 26304 non-null float64 \n",
|
| 962 |
+
" 3 wind_dir 26304 non-null category\n",
|
| 963 |
+
" 4 hm 26304 non-null float64 \n",
|
| 964 |
+
" 5 vap_pressure 26304 non-null float64 \n",
|
| 965 |
+
" 6 dewpoint_C 26304 non-null float64 \n",
|
| 966 |
+
" 7 loc_pressure 26304 non-null float64 \n",
|
| 967 |
+
" 8 sea_pressure 26304 non-null float64 \n",
|
| 968 |
+
" 9 solarRad 26304 non-null float64 \n",
|
| 969 |
+
" 10 snow_cm 26304 non-null float64 \n",
|
| 970 |
+
" 11 cloudcover 26304 non-null category\n",
|
| 971 |
+
" 12 lm_cloudcover 26304 non-null category\n",
|
| 972 |
+
" 13 low_cloudbase 26304 non-null float64 \n",
|
| 973 |
+
" 14 groundtemp 26304 non-null float64 \n",
|
| 974 |
+
" 15 O3 26304 non-null float64 \n",
|
| 975 |
+
" 16 NO2 26304 non-null float64 \n",
|
| 976 |
+
" 17 PM10 26304 non-null float64 \n",
|
| 977 |
+
" 18 PM25 26304 non-null float64 \n",
|
| 978 |
+
" 19 binary_class 26304 non-null int64 \n",
|
| 979 |
+
" 20 year 26304 non-null int64 \n",
|
| 980 |
+
" 21 month 26304 non-null int64 \n",
|
| 981 |
+
" 22 hour 26304 non-null int64 \n",
|
| 982 |
+
" 23 ground_temp - temp_C 26304 non-null float64 \n",
|
| 983 |
+
" 24 hour_sin 26304 non-null float64 \n",
|
| 984 |
+
" 25 hour_cos 26304 non-null float64 \n",
|
| 985 |
+
" 26 month_sin 26304 non-null float64 \n",
|
| 986 |
+
" 27 month_cos 26304 non-null float64 \n",
|
| 987 |
+
" 28 visi 26304 non-null float64 \n",
|
| 988 |
+
" 29 multi_class 26304 non-null int64 \n",
|
| 989 |
+
"dtypes: category(3), float64(22), int64(5)\n",
|
| 990 |
+
"memory usage: 5.7 MB\n"
|
| 991 |
+
]
|
| 992 |
+
}
|
| 993 |
+
],
|
| 994 |
+
"source": [
|
| 995 |
+
"df_busan_train.info()"
|
| 996 |
+
]
|
| 997 |
+
},
|
| 998 |
+
{
|
| 999 |
+
"cell_type": "code",
|
| 1000 |
+
"execution_count": 17,
|
| 1001 |
+
"metadata": {},
|
| 1002 |
+
"outputs": [],
|
| 1003 |
+
"source": [
|
| 1004 |
+
"df_seoul_train.to_csv(\"../../data/data_for_modeling/seoul_train.csv\")\n",
|
| 1005 |
+
"df_seoul_test.to_csv(\"../../data/data_for_modeling/seoul_test.csv\")\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
"df_busan_train.to_csv(\"../../data/data_for_modeling/busan_train.csv\")\n",
|
| 1008 |
+
"df_busan_test.to_csv(\"../../data/data_for_modeling/busan_test.csv\")\n",
|
| 1009 |
+
"\n",
|
| 1010 |
+
"df_incheon_train.to_csv(\"../../data/data_for_modeling/incheon_train.csv\")\n",
|
| 1011 |
+
"df_incheon_test.to_csv(\"../../data/data_for_modeling/incheon_test.csv\")\n",
|
| 1012 |
+
"\n",
|
| 1013 |
+
"df_daegu_train.to_csv(\"../../data/data_for_modeling/daegu_train.csv\")\n",
|
| 1014 |
+
"df_daegu_test.to_csv(\"../../data/data_for_modeling/daegu_test.csv\")\n",
|
| 1015 |
+
"\n",
|
| 1016 |
+
"df_daejeon_train.to_csv(\"../../data/data_for_modeling/daejeon_train.csv\")\n",
|
| 1017 |
+
"df_daejeon_test.to_csv(\"../../data/data_for_modeling/daejeon_test.csv\")\n",
|
| 1018 |
+
"\n",
|
| 1019 |
+
"df_gwangju_train.to_csv(\"../../data/data_for_modeling/gwangju_train.csv\")\n",
|
| 1020 |
+
"df_gwangju_test.to_csv(\"../../data/data_for_modeling/gwangju_test.csv\")\n",
|
| 1021 |
+
"\n",
|
| 1022 |
+
"df_seoul_train = pd.read_csv(\"../../data/data_for_modeling/seoul_train.csv\")\n",
|
| 1023 |
+
"df_seoul_test = pd.read_csv(\"../../data/data_for_modeling/seoul_test.csv\")\n"
|
| 1024 |
+
]
|
| 1025 |
+
},
|
| 1026 |
+
{
|
| 1027 |
+
"cell_type": "code",
|
| 1028 |
+
"execution_count": 18,
|
| 1029 |
+
"metadata": {},
|
| 1030 |
+
"outputs": [
|
| 1031 |
+
{
|
| 1032 |
+
"name": "stdout",
|
| 1033 |
+
"output_type": "stream",
|
| 1034 |
+
"text": [
|
| 1035 |
+
"Counter({2: 8266, 1: 481, 0: 13})\n",
|
| 1036 |
+
"Counter({2: 23686, 1: 2579, 0: 39})\n",
|
| 1037 |
+
"Counter({2: 8455, 1: 281, 0: 24})\n",
|
| 1038 |
+
"Counter({2: 24694, 1: 1516, 0: 94})\n",
|
| 1039 |
+
"Counter({2: 7373, 1: 1205, 0: 182})\n",
|
| 1040 |
+
"Counter({2: 21893, 1: 3892, 0: 519})\n",
|
| 1041 |
+
"Counter({2: 8631, 1: 128, 0: 1})\n",
|
| 1042 |
+
"Counter({2: 25149, 1: 1107, 0: 48})\n",
|
| 1043 |
+
"Counter({2: 8089, 1: 618, 0: 53})\n",
|
| 1044 |
+
"Counter({2: 23471, 1: 2660, 0: 173})\n",
|
| 1045 |
+
"Counter({2: 8087, 1: 643, 0: 30})\n",
|
| 1046 |
+
"Counter({2: 23798, 1: 2411, 0: 95})\n"
|
| 1047 |
+
]
|
| 1048 |
+
}
|
| 1049 |
+
],
|
| 1050 |
+
"source": [
|
| 1051 |
+
"print(Counter(df_seoul_test['multi_class']))\n",
|
| 1052 |
+
"print(Counter(df_seoul_train['multi_class']))\n",
|
| 1053 |
+
"\n",
|
| 1054 |
+
"print(Counter(df_busan_test['multi_class']))\n",
|
| 1055 |
+
"print(Counter(df_busan_train['multi_class']))\n",
|
| 1056 |
+
"\n",
|
| 1057 |
+
"print(Counter(df_incheon_test['multi_class']))\n",
|
| 1058 |
+
"print(Counter(df_incheon_train['multi_class']))\n",
|
| 1059 |
+
"\n",
|
| 1060 |
+
"print(Counter(df_daegu_test['multi_class']))\n",
|
| 1061 |
+
"print(Counter(df_daegu_train['multi_class']))\n",
|
| 1062 |
+
"\n",
|
| 1063 |
+
"print(Counter(df_daejeon_test['multi_class']))\n",
|
| 1064 |
+
"print(Counter(df_daejeon_train['multi_class']))\n",
|
| 1065 |
+
"\n",
|
| 1066 |
+
"print(Counter(df_gwangju_test['multi_class']))\n",
|
| 1067 |
+
"print(Counter(df_gwangju_train['multi_class']))"
|
| 1068 |
+
]
|
| 1069 |
+
},
|
| 1070 |
+
{
|
| 1071 |
+
"cell_type": "code",
|
| 1072 |
+
"execution_count": null,
|
| 1073 |
+
"metadata": {},
|
| 1074 |
+
"outputs": [],
|
| 1075 |
+
"source": []
|
| 1076 |
+
}
|
| 1077 |
+
],
|
| 1078 |
+
"metadata": {
|
| 1079 |
+
"kernelspec": {
|
| 1080 |
+
"display_name": "Python 3",
|
| 1081 |
+
"language": "python",
|
| 1082 |
+
"name": "python3"
|
| 1083 |
+
},
|
| 1084 |
+
"language_info": {
|
| 1085 |
+
"codemirror_mode": {
|
| 1086 |
+
"name": "ipython",
|
| 1087 |
+
"version": 3
|
| 1088 |
+
},
|
| 1089 |
+
"file_extension": ".py",
|
| 1090 |
+
"mimetype": "text/x-python",
|
| 1091 |
+
"name": "python",
|
| 1092 |
+
"nbconvert_exporter": "python",
|
| 1093 |
+
"pygments_lexer": "ipython3",
|
| 1094 |
+
"version": "3.8.10"
|
| 1095 |
+
}
|
| 1096 |
+
},
|
| 1097 |
+
"nbformat": 4,
|
| 1098 |
+
"nbformat_minor": 2
|
| 1099 |
+
}
|
Analysis_code/2.make_oversample_data/gpu0.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Analysis_code/2.make_oversample_data/gpu1.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_10000_1.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import optuna
|
| 5 |
+
from ctgan import CTGAN
|
| 6 |
+
import torch
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
# ==================== Constants ====================
# City datasets processed by this script.
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
# Years kept for training; later years are held out for testing.
TRAIN_YEARS = [2018, 2019]
# Number of synthetic rows to generate for minority class 0.
TARGET_SAMPLES_CLASS_0 = 10000
# Class 1 is topped up so (real + synthetic) reaches this total.
TARGET_SAMPLES_CLASS_1_BASE = 10000
RANDOM_STATE = 42

# Optuna optimization settings (trials per class; class 0 is rarer, so it
# gets a larger search budget)
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class CTGAN hyperparameter search ranges.
# Class 0 has very few samples, so smaller networks/batches are searched.
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before CTGAN training; they are re-computed
# afterwards (see add_derived_features) so CTGAN never has to model them.
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 41 |
+
|
| 42 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 43 |
+
|
| 44 |
+
def setup_environment():
    """Configure the runtime environment.

    Selects the compute device (CUDA when available, otherwise CPU) and
    silences the Optuna UserWarnings raised when tuples are passed to
    ``suggest_categorical``.

    Returns:
        torch.device: the selected device.
    """
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    print(f"Using device: {device}")
    # Tuple choices in suggest_categorical trigger noisy distribution warnings.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region's CSV and split it into features and target.

    Args:
        file_path: path to the region's training CSV (index in column 0).
        train_years: list of years to keep for training.

    Returns:
        (data, X, y): the filtered full frame, the feature frame (without
        the target columns and without the derived columns listed in
        COLUMNS_TO_DROP), and the 'multi_class' target series.
    """
    data = pd.read_csv(file_path, index_col=0)
    # .copy() makes the year-filtered frame independent of the full CSV
    # frame, so the astype assignments below cannot raise
    # SettingWithCopyWarning or silently fail to write through a view.
    data = data.loc[data['year'].isin(train_years), :].copy()
    # CTGAN treats non-float columns as discrete, so cast these to int.
    data['cloudcover'] = data['cloudcover'].astype('int')
    data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')

    X = data.drop(columns=['multi_class', 'binary_class'])
    y = data['multi_class']

    # Drop derived columns; they are re-created after oversampling
    # (see add_derived_features), so CTGAN never needs to model them.
    X = X.drop(columns=COLUMNS_TO_DROP)

    return data, X, y
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return column names whose dtype is not float64.

    These columns are treated as discrete/categorical by CTGAN.
    """
    non_float_columns = []
    for name, dtype in df.dtypes.items():
        if dtype != 'float64':
            non_float_columns.append(name)
    return non_float_columns
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective scoring CTGAN hyperparameters for one class.

    Args:
        data: training frame containing a 'multi_class' column.
        class_label: class whose rows the CTGAN should learn (0 or 1).
        categorical_features: column names CTGAN must treat as discrete.
        hp_ranges: per-class hyperparameter search space.

    Returns:
        An objective callable for `study.optimize` that returns the
        negative squared gap between real and synthetic 'visi' mean/std
        (so maximizing it minimizes the distribution gap).
    """
    subset = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one hyperparameter configuration from the search space.
        params = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        model = CTGAN(**params)
        model.set_random_state(RANDOM_STATE)

        # Train on this class only, then oversample 2x for evaluation.
        model.fit(subset, discrete_columns=categorical_features)
        synthetic = model.sample(len(subset) * 2)

        # Score fidelity of the continuous 'visi' variable: squared gaps
        # between the real and synthetic mean and standard deviation.
        mean_gap = subset['visi'].mean() - synthetic['visi'].mean()
        std_gap = subset['visi'].std() - synthetic['visi'].std()
        return -(mean_gap ** 2 + std_gap ** 2)

    return objective
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN hyperparameters with Optuna, then generate synthetic rows.

    Args:
        data: training frame containing a 'multi_class' column.
        class_label: class to model (0 or 1).
        categorical_features: column names CTGAN must treat as discrete.
        hp_ranges: per-class hyperparameter search space.
        n_trials: number of Optuna trials.
        target_samples: number of synthetic rows to generate.

    Returns:
        (generated samples DataFrame, fitted CTGAN model)
    """
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Seed the TPE sampler so the hyperparameter search itself is
    # reproducible, matching the fixed RANDOM_STATE used for CTGAN.
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=RANDOM_STATE),
    )
    study.optimize(objective, n_trials=n_trials)

    # Re-train a fresh CTGAN with the best hyperparameters found.
    best_params = study.best_params
    ctgan = CTGAN(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=best_params["generator_dim"],
        discriminator_dim=best_params["discriminator_dim"],
        batch_size=best_params["batch_size"],
        discriminator_steps=best_params["discriminator_steps"],
        pac=best_params["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then sample the requested amount.
    class_data = data[data['multi_class'] == class_label]
    ctgan.fit(class_data, discrete_columns=categorical_features)
    generated_samples = ctgan.sample(target_samples)

    return generated_samples, ctgan
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Recreate the derived columns that were dropped before oversampling.

    Adds: binary_class (0 if multi_class == 2, else 1), cyclic hour/month
    sin-cos encodings, and the ground-air temperature difference.

    Args:
        df: frame with 'multi_class', 'hour', 'month', 'groundtemp',
            and 'temp_C' columns.

    Returns:
        A copy of *df* with the derived columns appended; the input
        frame is left untouched.
    """
    out = df.copy()
    # Binary target: class 2 (the majority class) maps to 0, others to 1.
    out['binary_class'] = out['multi_class'].apply(lambda cls: 0 if cls == 2 else 1)
    # Cyclic encodings so hour 23 sits next to hour 0 and December next
    # to January.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """Augment one region's training data using CTGAN only.

    Loads the region CSV, tunes and trains one CTGAN per minority class
    (0 and 1), saves the fitted models, filters synthetic rows to plausible
    'visi' ranges, and writes both the augmented-only rows and the merged
    (original + synthetic) dataset.

    Args:
        file_path: Input training CSV path.
        output_path: Output CSV path for the merged dataset.
        model_save_dir: Directory where fitted CTGAN models are saved.
    """
    # Region name from the file path, e.g. '.../seoul_train.csv' -> 'seoul'.
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach the target so CTGAN models features and label jointly.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Class 1 target: generate only the shortfall up to the configured base
    # count. Guarded with max(0, ...) — a negative count would make
    # ctgan.sample() fail when the class already exceeds the base.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = max(0, TARGET_SAMPLES_CLASS_1_BASE - count_class_1)

    # Optimize and sample for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize and sample for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Persist the fitted models.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    model_path_0 = model_save_dir / f'ctgan_only_10000_1_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    model_path_1 = model_save_dir / f'ctgan_only_10000_1_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose visibility falls in the class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the CTGAN-generated rows on their own.
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge original rows with the filtered synthetic rows and rebuild the
    # derived columns for the whole frame.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)
    ctgan_data = add_derived_features(ctgan_data)

    # Report on the augmented-only file.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class-2 rows from the merged frame and re-add the untouched
    # originals (they still carry the original derived columns).
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists and write the merged dataset.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_data.to_csv(output_path, index=False)

    # Final per-class report.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# ==================== Main execution ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output pair per region.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/ctgan10000/ctgan10000_1_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    # Augment every region sequentially.
    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
| 316 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_10000_2.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== Constant definitions ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2018, 2020]
TARGET_SAMPLES_CLASS_0 = 10000       # synthetic rows to generate for class 0
TARGET_SAMPLES_CLASS_1_BASE = 10000  # desired class-1 total; only the shortfall is generated
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before CTGAN training and rebuilt afterwards
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Configure the runtime: pick the compute device and silence Optuna warnings."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")
    # Optuna emits UserWarnings for tuple-valued categorical choices; silence them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region CSV and split it into features and target.

    Args:
        file_path: Path to the training CSV (first column is the index).
        train_years: Years whose rows are kept for training.

    Returns:
        Tuple of (filtered full frame, feature frame X, target series y).
    """
    data = pd.read_csv(file_path, index_col=0)
    # Keep only the requested training years.
    data = data.loc[data['year'].isin(train_years), :]
    # Cloud-cover columns are integer-coded categories.
    for col in ('cloudcover', 'lm_cloudcover'):
        data[col] = data[col].astype('int')

    y = data['multi_class']
    X = data.drop(columns=['multi_class', 'binary_class'])

    # Derived columns are removed here and rebuilt after augmentation.
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the names of columns treated as categorical (any non-float64 dtype)."""
    return [name for name, dtype in df.dtypes.items() if dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build the Optuna objective used to tune CTGAN for a single class.

    Args:
        data: Training data (must contain a 'multi_class' column).
        class_label: Class label to model (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.

    Returns:
        An Optuna objective function; higher return values are better.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Sample hyperparameters from the configured search space.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model for this trial.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train on the single-class slice only.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate twice as many samples as the real slice for comparison.
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluate: compare the synthetic 'visi' distribution against the real one.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared differences of mean and std; negated because the study maximizes.
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, then train the best model and sample from it.

    Args:
        data: Training data (must contain 'multi_class').
        class_label: Class to model (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated samples DataFrame, fitted CTGAN model)
    """
    # Run the hyperparameter search.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        create_ctgan_objective(data, class_label, categorical_features, hp_ranges),
        n_trials=n_trials,
    )

    # Rebuild the model with the winning configuration.
    best = study.best_params
    ctgan = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on the single-class slice, then draw the requested samples.
    single_class = data[data['multi_class'] == class_label]
    ctgan.fit(single_class, discrete_columns=categorical_features)
    return ctgan.sample(target_samples), ctgan
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived feature columns that were dropped before CTGAN training.

    Args:
        df: Frame containing at least 'multi_class', 'hour', 'month',
            'groundtemp' and 'temp_C' columns.

    Returns:
        A copy of *df* with the binary label, cyclic time encodings and the
        ground/air temperature difference column added.
    """
    out = df.copy()
    # Binary label: class 2 maps to 0, every other class to 1.
    out['binary_class'] = np.where(out['multi_class'] == 2, 0, 1)
    # Cyclic encodings for hour-of-day (period 24) and month-of-year (period 12).
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    # Ground minus air temperature (column name kept verbatim, spaces included).
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """Augment one region's training data using CTGAN only.

    Loads the region CSV, tunes and trains one CTGAN per minority class
    (0 and 1), saves the fitted models, filters synthetic rows to plausible
    'visi' ranges, and writes both the augmented-only rows and the merged
    (original + synthetic) dataset.

    Args:
        file_path: Input training CSV path.
        output_path: Output CSV path for the merged dataset.
        model_save_dir: Directory where fitted CTGAN models are saved.
    """
    # Region name from the file path, e.g. '.../seoul_train.csv' -> 'seoul'.
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach the target so CTGAN models features and label jointly.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Class 1 target: generate only the shortfall up to the configured base
    # count. Guarded with max(0, ...) — a negative count would make
    # ctgan.sample() fail when the class already exceeds the base.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = max(0, TARGET_SAMPLES_CLASS_1_BASE - count_class_1)

    # Optimize and sample for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize and sample for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Persist the fitted models.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    model_path_0 = model_save_dir / f'ctgan_only_10000_2_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    model_path_1 = model_save_dir / f'ctgan_only_10000_2_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose visibility falls in the class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the CTGAN-generated rows on their own.
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge original rows with the filtered synthetic rows and rebuild the
    # derived columns for the whole frame.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)
    ctgan_data = add_derived_features(ctgan_data)

    # Report on the augmented-only file.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class-2 rows from the merged frame and re-add the untouched
    # originals (they still carry the original derived columns).
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists and write the merged dataset.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_data.to_csv(output_path, index=False)

    # Final per-class report.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== Main execution ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output pair per region.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/ctgan10000/ctgan10000_2_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    # Augment every region sequentially.
    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_10000_3.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== Constant definitions ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2019, 2020]
TARGET_SAMPLES_CLASS_0 = 10000       # synthetic rows to generate for class 0
TARGET_SAMPLES_CLASS_1_BASE = 10000  # desired class-1 total; only the shortfall is generated
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before CTGAN training and rebuilt afterwards
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Configure the runtime: pick the compute device and silence Optuna warnings."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")
    # Optuna emits UserWarnings for tuple-valued categorical choices; silence them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region CSV and split it into features and target.

    Args:
        file_path: Path to the training CSV (first column is the index).
        train_years: Years whose rows are kept for training.

    Returns:
        Tuple of (filtered full frame, feature frame X, target series y).
    """
    data = pd.read_csv(file_path, index_col=0)
    # Keep only the requested training years.
    data = data.loc[data['year'].isin(train_years), :]
    # Cloud-cover columns are integer-coded categories.
    for col in ('cloudcover', 'lm_cloudcover'):
        data[col] = data[col].astype('int')

    y = data['multi_class']
    X = data.drop(columns=['multi_class', 'binary_class'])

    # Derived columns are removed here and rebuilt after augmentation.
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the names of columns treated as categorical (any non-float64 dtype)."""
    return [name for name, dtype in df.dtypes.items() if dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build the Optuna objective used to tune CTGAN for a single class.

    Args:
        data: Training data (must contain a 'multi_class' column).
        class_label: Class label to model (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.

    Returns:
        An Optuna objective function; higher return values are better.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Sample hyperparameters from the configured search space.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model for this trial.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train on the single-class slice only.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate twice as many samples as the real slice for comparison.
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluate: compare the synthetic 'visi' distribution against the real one.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared differences of mean and std; negated because the study maximizes.
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, then train the best model and sample from it.

    Args:
        data: Training data (must contain 'multi_class').
        class_label: Class to model (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated samples DataFrame, fitted CTGAN model)
    """
    # Run the hyperparameter search.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        create_ctgan_objective(data, class_label, categorical_features, hp_ranges),
        n_trials=n_trials,
    )

    # Rebuild the model with the winning configuration.
    best = study.best_params
    ctgan = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on the single-class slice, then draw the requested samples.
    single_class = data[data['multi_class'] == class_label]
    ctgan.fit(single_class, discrete_columns=categorical_features)
    return ctgan.sample(target_samples), ctgan
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived feature columns that were dropped before CTGAN training.

    Args:
        df: Frame containing at least 'multi_class', 'hour', 'month',
            'groundtemp' and 'temp_C' columns.

    Returns:
        A copy of *df* with the binary label, cyclic time encodings and the
        ground/air temperature difference column added.
    """
    out = df.copy()
    # Binary label: class 2 maps to 0, every other class to 1.
    out['binary_class'] = np.where(out['multi_class'] == 2, 0, 1)
    # Cyclic encodings for hour-of-day (period 24) and month-of-year (period 12).
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    # Ground minus air temperature (column name kept verbatim, spaces included).
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Augment one region's training data using CTGAN only.

    Loads the region's training CSV, tunes and fits one CTGAN per fog
    class (0 and 1), keeps only synthetic rows whose visibility falls in
    each class's valid band, then writes (a) the synthetic rows alone and
    (b) the originals merged with the synthetic rows. Fitted models are
    pickled under model_save_dir.

    Args:
        file_path: Input training CSV path.
        output_path: Output CSV path for the merged data.
        model_save_dir: Directory where fitted CTGAN models are saved.
    """
    # Region name comes from the file stem, e.g. 'seoul_train' -> 'seoul'.
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach the target so CTGAN models each class's joint distribution.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Class 1 is topped up to the base target. Clamp at zero so we never
    # request a negative sample count when the class is already large
    # enough (the unclamped subtraction would make ctgan.sample() fail).
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = max(0, TARGET_SAMPLES_CLASS_1_BASE - count_class_1)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Persist the fitted models for later reuse.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    model_path_0 = model_save_dir / f'ctgan_only_10000_3_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    model_path_1 = model_save_dir / f'ctgan_only_10000_3_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only samples whose visibility lies in each class's valid band
    # (class 0: dense fog < 100 m, class 1: fog 100-500 m).
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the synthetic rows on their own (with derived columns rebuilt).
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge originals with the filtered synthetic rows.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)
    ctgan_data = add_derived_features(ctgan_data)

    # Report synthetic-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Replace the (re-derived) class-2 rows with the untouched originals.
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists and write the merged result.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_data.to_csv(output_path, index=False)

    # Report final per-class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
    setup_environment()

    model_save_dir = Path('../../save_model/oversampling_models')
    # One input/output pair per region.
    for region in REGIONS:
        process_region(
            f'../../../data/data_for_modeling/{region}_train.csv',
            f'../../../data/data_oversampled/ctgan10000/ctgan10000_3_{region}.csv',
            model_save_dir,
        )
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_20000_1.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import optuna
|
| 5 |
+
from ctgan import CTGAN
|
| 6 |
+
import torch
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
# ==================== Constants ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2018, 2019]
TARGET_SAMPLES_CLASS_0 = 20000
TARGET_SAMPLES_CLASS_1_BASE = 20000
RANDOM_STATE = 42

# Optuna optimization settings (trials per class)
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class CTGAN hyperparameter search spaces.
# Tuples are (low, high) ranges for suggest_int; lists are categorical choices.
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before CTGAN training (rebuilt by add_derived_features)
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 41 |
+
|
| 42 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 43 |
+
|
| 44 |
+
def setup_environment():
    """Configure the runtime environment (device report, warning filters)."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna warns about non-primitive categorical choices (tuples); silence it.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """
    Load a regional training CSV and split it into features and target.

    Args:
        file_path: Path to the CSV file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        (data, X, y): the filtered raw frame, the feature frame with
        target and derived columns removed, and the 'multi_class' target.
    """
    data = pd.read_csv(file_path, index_col=0)
    data = data[data['year'].isin(train_years)]
    # Cloud-cover levels are categorical; store them as integers.
    for column in ('cloudcover', 'lm_cloudcover'):
        data[column] = data[column].astype('int')

    y = data['multi_class']
    # Drop both targets plus the derived columns CTGAN should not see.
    X = data.drop(columns=['multi_class', 'binary_class'] + COLUMNS_TO_DROP)

    return data, X, y
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return names of columns treated as categorical (any non-float64 dtype)."""
    return [name for name in df.columns if df[name].dtype != 'float64']
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build an Optuna objective that scores CTGAN hyperparameters.

    Each trial trains a CTGAN on the rows of one class, samples twice as
    many synthetic rows, and scores the fit by how closely the mean and
    std of the continuous 'visi' column match the real data (negated
    squared error, so Optuna maximizes it).

    Args:
        data: Training data including the 'multi_class' column.
        class_label: Class (0 or 1) whose rows are modeled.
        categorical_features: Column names treated as discrete by CTGAN.
        hp_ranges: Search space per hyperparameter.

    Returns:
        An objective function suitable for study.optimize.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one candidate configuration from the search space.
        params = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        # Train a candidate CTGAN with a fixed seed for comparability.
        model = CTGAN(**params)
        model.set_random_state(RANDOM_STATE)
        model.fit(class_data, discrete_columns=categorical_features)

        # Sample 2x the class size for a more stable distribution estimate.
        synthetic = model.sample(len(class_data) * 2)

        # Score: squared gap between real and synthetic visi mean/std.
        real_visi = class_data['visi']
        fake_visi = synthetic['visi']
        gap = ((real_visi.mean() - fake_visi.mean()) ** 2 +
               (real_visi.std() - fake_visi.std()) ** 2)
        return -gap

    return objective
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """
    Tune CTGAN with Optuna, then train the best model and sample rows.

    Args:
        data: Training data including the 'multi_class' column.
        class_label: Class (0 or 1) to model.
        categorical_features: Column names treated as discrete by CTGAN.
        hp_ranges: Hyperparameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated samples dataframe, fitted CTGAN model)
    """
    # Search the hyperparameter space for this class.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        create_ctgan_objective(data, class_label, categorical_features, hp_ranges),
        n_trials=n_trials,
    )

    # Rebuild a CTGAN from the best trial's configuration.
    best = study.best_params
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"],
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on the class rows and synthesis of the requested volume.
    rows_for_class = data[data['multi_class'] == class_label]
    model.fit(rows_for_class, discrete_columns=categorical_features)
    return model.sample(target_samples), model
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Rebuild the derived columns that were dropped before oversampling.

    Adds the binary target, cyclic sin/cos encodings of hour and month,
    and the ground-air temperature difference.

    Args:
        df: Frame containing 'multi_class', 'hour', 'month',
            'groundtemp' and 'temp_C' columns.

    Returns:
        A copy of the frame with the derived columns appended.
    """
    out = df.copy()
    # Class 2 ("clear") maps to binary 0; fog classes 0/1 map to 1.
    out['binary_class'] = (out['multi_class'] != 2).astype(int)
    # Cyclic encodings so 23h is close to 0h and December close to January.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Augment one region's training data using CTGAN only.

    Loads the region's training CSV, tunes and fits one CTGAN per fog
    class (0 and 1), keeps only synthetic rows whose visibility falls in
    each class's valid band, then writes (a) the synthetic rows alone and
    (b) the originals merged with the synthetic rows. Fitted models are
    pickled under model_save_dir.

    Args:
        file_path: Input training CSV path.
        output_path: Output CSV path for the merged data.
        model_save_dir: Directory where fitted CTGAN models are saved.
    """
    # Region name comes from the file stem, e.g. 'seoul_train' -> 'seoul'.
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach the target so CTGAN models each class's joint distribution.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Class 1 is topped up to the base target. Clamp at zero so we never
    # request a negative sample count when the class is already large
    # enough (the unclamped subtraction would make ctgan.sample() fail).
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = max(0, TARGET_SAMPLES_CLASS_1_BASE - count_class_1)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Persist the fitted models for later reuse.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    model_path_0 = model_save_dir / f'ctgan_only_20000_1_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    model_path_1 = model_save_dir / f'ctgan_only_20000_1_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only samples whose visibility lies in each class's valid band
    # (class 0: dense fog < 100 m, class 1: fog 100-500 m).
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the synthetic rows on their own (with derived columns rebuilt).
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge originals with the filtered synthetic rows.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)
    ctgan_data = add_derived_features(ctgan_data)

    # Report synthetic-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Replace the (re-derived) class-2 rows with the untouched originals.
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists and write the merged result.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_data.to_csv(output_path, index=False)

    # Report final per-class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
    setup_environment()

    model_save_dir = Path('../../save_model/oversampling_models')
    # One input/output pair per region.
    for region in REGIONS:
        process_region(
            f'../../../data/data_for_modeling/{region}_train.csv',
            f'../../../data/data_oversampled/ctgan20000/ctgan20000_1_{region}.csv',
            model_save_dir,
        )
|
| 316 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_20000_2.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== ์์ ์ ์ ====================
|
| 11 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 12 |
+
TRAIN_YEARS = [2018, 2020]
|
| 13 |
+
TARGET_SAMPLES_CLASS_0 = 20000
|
| 14 |
+
TARGET_SAMPLES_CLASS_1_BASE = 20000
|
| 15 |
+
RANDOM_STATE = 42
|
| 16 |
+
|
| 17 |
+
# Optuna ์ต์ ํ ์ค์
|
| 18 |
+
CLASS_0_TRIALS = 50
|
| 19 |
+
CLASS_1_TRIALS = 30
|
| 20 |
+
|
| 21 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 22 |
+
CLASS_0_HP_RANGES = {
|
| 23 |
+
'embedding_dim': (64, 128),
|
| 24 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 25 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'pac': [4, 8],
|
| 27 |
+
'batch_size': [64, 128, 256],
|
| 28 |
+
'discriminator_steps': (1, 3)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
CLASS_1_HP_RANGES = {
|
| 32 |
+
'embedding_dim': (128, 512),
|
| 33 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 34 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'pac': [4, 8],
|
| 36 |
+
'batch_size': [256, 512, 1024],
|
| 37 |
+
'discriminator_steps': (1, 5)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 41 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Configure the runtime environment (device report, warning filters)."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna warns about non-primitive categorical choices (tuples); silence it.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """
    Load a regional training CSV and split it into features and target.

    Args:
        file_path: Path to the CSV file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        (data, X, y): the filtered raw frame, the feature frame with
        target and derived columns removed, and the 'multi_class' target.
    """
    data = pd.read_csv(file_path, index_col=0)
    data = data[data['year'].isin(train_years)]
    # Cloud-cover levels are categorical; store them as integers.
    for column in ('cloudcover', 'lm_cloudcover'):
        data[column] = data[column].astype('int')

    y = data['multi_class']
    # Drop both targets plus the derived columns CTGAN should not see.
    X = data.drop(columns=['multi_class', 'binary_class'] + COLUMNS_TO_DROP)

    return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return names of columns treated as categorical (any non-float64 dtype)."""
    return [name for name in df.columns if df[name].dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build an Optuna objective that scores CTGAN hyperparameters.

    Each trial trains a CTGAN on the rows of one class, samples twice as
    many synthetic rows, and scores the fit by how closely the mean and
    std of the continuous 'visi' column match the real data (negated
    squared error, so Optuna maximizes it).

    Args:
        data: Training data including the 'multi_class' column.
        class_label: Class (0 or 1) whose rows are modeled.
        categorical_features: Column names treated as discrete by CTGAN.
        hp_ranges: Search space per hyperparameter.

    Returns:
        An objective function suitable for study.optimize.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one candidate configuration from the search space.
        params = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        # Train a candidate CTGAN with a fixed seed for comparability.
        model = CTGAN(**params)
        model.set_random_state(RANDOM_STATE)
        model.fit(class_data, discrete_columns=categorical_features)

        # Sample 2x the class size for a more stable distribution estimate.
        synthetic = model.sample(len(class_data) * 2)

        # Score: squared gap between real and synthetic visi mean/std.
        real_visi = class_data['visi']
        fake_visi = synthetic['visi']
        gap = ((real_visi.mean() - fake_visi.mean()) ** 2 +
               (real_visi.std() - fake_visi.std()) ** 2)
        return -gap

    return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """
    Tune CTGAN with Optuna, then train the best model and sample rows.

    Args:
        data: Training data including the 'multi_class' column.
        class_label: Class (0 or 1) to model.
        categorical_features: Column names treated as discrete by CTGAN.
        hp_ranges: Hyperparameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated samples dataframe, fitted CTGAN model)
    """
    # Search the hyperparameter space for this class.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        create_ctgan_objective(data, class_label, categorical_features, hp_ranges),
        n_trials=n_trials,
    )

    # Rebuild a CTGAN from the best trial's configuration.
    best = study.best_params
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"],
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on the class rows and synthesis of the requested volume.
    rows_for_class = data[data['multi_class'] == class_label]
    model.fit(rows_for_class, discrete_columns=categorical_features)
    return model.sample(target_samples), model
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Rebuild the derived columns that were dropped before oversampling.

    Adds the binary target, cyclic sin/cos encodings of hour and month,
    and the ground-air temperature difference.

    Args:
        df: Frame containing 'multi_class', 'hour', 'month',
            'groundtemp' and 'temp_C' columns.

    Returns:
        A copy of the frame with the derived columns appended.
    """
    out = df.copy()
    # Class 2 ("clear") maps to binary 0; fog classes 0/1 map to 1.
    out['binary_class'] = (out['multi_class'] != 2).astype(int)
    # Cyclic encodings so 23h is close to 0h and December close to January.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Augment one region's training data using CTGAN only.

    Trains one CTGAN per class (0 and 1), saves both models, filters the
    synthetic rows to plausible 'visi' ranges, writes a synthetic-only CSV,
    and writes the merged real + synthetic CSV to `output_path`.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Region name comes from the file name (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach multi_class so CTGAN models features and target jointly.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Per-class sample counts.
    count_class_0 = (y == 0).sum()  # NOTE(review): computed but not used below
    count_class_1 = (y == 1).sum()
    # Class 1 is topped up to the base total. NOTE(review): this is negative
    # when the real class-1 count exceeds the base — confirm intended.
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - count_class_1

    # CTGAN optimization and sampling for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # CTGAN optimization and sampling for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'ctgan_only_20000_2_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'ctgan_only_20000_2_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose visibility falls in each class's range:
    # class 0 -> [0, 100), class 1 -> [100, 500).
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the augmented data only (CTGAN-generated samples).
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Write into a sibling 'augmented_only' folder next to the output folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge the original data with the filtered CTGAN samples.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived columns dropped during preprocessing.
    ctgan_data = add_derived_features(ctgan_data)

    # Report counts for the synthetic-only output.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the merged frame, then append the original class-2 rows.
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final merged dataset.
    final_data.to_csv(output_path, index=False)

    # Report final per-class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== Main execution ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output path pair per region; outputs are the '_2' variant.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/ctgan20000/ctgan20000_2_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    # Augment each region independently.
    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_20000_3.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== Constant definitions ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']  # regions to process
TRAIN_YEARS = [2019, 2020]  # years kept for training
TARGET_SAMPLES_CLASS_0 = 20000  # synthetic rows to draw for class 0
TARGET_SAMPLES_CLASS_1_BASE = 20000  # desired total for class 1 (real count is subtracted)
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),          # int range (low, high)
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)        # int range (low, high)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Columns dropped before CTGAN training (restored later by add_derived_features)
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Configure the runtime: pick a compute device and silence noisy warnings.

    Returns:
        torch.device: "cuda" when a GPU is available, otherwise "cpu".
    """
    has_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if has_gpu else "cpu")
    print(f"Using device: {device}")
    # Optuna emits UserWarnings from its distributions module; ignore them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a regional CSV and split it into features and target.

    Args:
        file_path: Path to the CSV file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        Tuple of (filtered full dataframe, feature frame X, target series y).
    """
    frame = pd.read_csv(file_path, index_col=0)
    # Keep only the configured training years.
    frame = frame.loc[frame['year'].isin(train_years), :]
    # Cast cloud-cover columns to int so get_categorical_feature_names()
    # picks them up as categorical (it treats non-float64 as categorical).
    for col in ('cloudcover', 'lm_cloudcover'):
        frame[col] = frame[col].astype('int')

    features = frame.drop(columns=['multi_class', 'binary_class'])
    target = frame['multi_class']

    # Derived columns are removed here and rebuilt later by add_derived_features().
    features.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return frame, features, target
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return column names whose dtype is not float64 (treated as categorical)."""
    return [name for name, dtype in df.dtypes.items() if dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective that scores CTGAN hyperparameters.

    The score is the negated squared gap between the real and generated
    'visi' distributions (mean and std), so maximizing it pulls the
    synthetic distribution toward the real one.

    Args:
        data: Training dataframe containing a 'multi_class' column.
        class_label: Class label (0 or 1) whose rows are modeled.
        categorical_features: Column names CTGAN treats as discrete.
        hp_ranges: Search space per hyperparameter.

    Returns:
        A callable usable with study.optimize().
    """
    subset = data[data['multi_class'] == class_label]

    def objective(trial):
        # Suggestion order is kept identical to preserve Optuna's sampling.
        params = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        # Build and seed the candidate model.
        model = CTGAN(**params)
        model.set_random_state(RANDOM_STATE)

        # Train on this class's rows only.
        model.fit(subset, discrete_columns=categorical_features)

        # Draw twice as many rows as the real subset.
        synthetic = model.sample(len(subset) * 2)

        # Compare the continuous 'visi' column between real and synthetic.
        real_visi = subset['visi']
        fake_visi = synthetic['visi']

        # Squared distance between the first two moments.
        gap = ((real_visi.mean() - fake_visi.mean())**2 +
               (real_visi.std() - fake_visi.std())**2)
        return -gap

    return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain with the best settings, and sample.

    Args:
        data: Training dataframe containing a 'multi_class' column.
        class_label: Class label (0 or 1) to model.
        categorical_features: Column names CTGAN treats as discrete.
        hp_ranges: Hyperparameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        Tuple of (generated sample dataframe, fitted CTGAN model).
    """
    # Build the objective for this class.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Rebuild CTGAN with the winning configuration.
    best = study.best_params
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"]
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then draw the requested sample count.
    class_rows = data[data['multi_class'] == class_label]
    model.fit(class_rows, discrete_columns=categorical_features)
    synthetic = model.sample(target_samples)

    return synthetic, model
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Re-create the derived columns that were dropped before CTGAN training.

    Args:
        df: Dataframe with 'multi_class', 'hour', 'month', 'groundtemp'
            and 'temp_C' columns.

    Returns:
        A copy of the dataframe with binary target, cyclic time encodings,
        and the ground/air temperature difference appended.
    """
    result = df.copy()
    # multi_class == 2 is the binary negative; classes 0 and 1 are positive.
    result['binary_class'] = result['multi_class'].apply(lambda x: 0 if x == 2 else 1)
    # Sine/cosine encodings keep the hour and month cycles continuous.
    result['hour_sin'] = np.sin(2 * np.pi * result['hour'] / 24)
    result['hour_cos'] = np.cos(2 * np.pi * result['hour'] / 24)
    result['month_sin'] = np.sin(2 * np.pi * result['month'] / 12)
    result['month_cos'] = np.cos(2 * np.pi * result['month'] / 12)
    result['ground_temp - temp_C'] = result['groundtemp'] - result['temp_C']
    return result
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Augment one region's training data using CTGAN only.

    Trains one CTGAN per class (0 and 1), saves both models, filters the
    synthetic rows to plausible 'visi' ranges, writes a synthetic-only CSV,
    and writes the merged real + synthetic CSV to `output_path`.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Region name comes from the file name (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach multi_class so CTGAN models features and target jointly.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Per-class sample counts.
    count_class_0 = (y == 0).sum()  # NOTE(review): computed but not used below
    count_class_1 = (y == 1).sum()
    # Class 1 is topped up to the base total. NOTE(review): this is negative
    # when the real class-1 count exceeds the base — confirm intended.
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - count_class_1

    # CTGAN optimization and sampling for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # CTGAN optimization and sampling for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'ctgan_only_20000_3_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'ctgan_only_20000_3_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose visibility falls in each class's range:
    # class 0 -> [0, 100), class 1 -> [100, 500).
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the augmented data only (CTGAN-generated samples).
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Write into a sibling 'augmented_only' folder next to the output folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge the original data with the filtered CTGAN samples.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived columns dropped during preprocessing.
    ctgan_data = add_derived_features(ctgan_data)

    # Report counts for the synthetic-only output.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the merged frame, then append the original class-2 rows.
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final merged dataset.
    final_data.to_csv(output_path, index=False)

    # Report final per-class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== Main execution ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output path pair per region; outputs are the '_3' variant.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/ctgan20000/ctgan20000_3_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    # Augment each region independently.
    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_7000_1.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== Constant definitions ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']  # regions to process
TRAIN_YEARS = [2018, 2019]  # years kept for training
TARGET_SAMPLES_CLASS_0 = 7000  # synthetic rows to draw for class 0
TARGET_SAMPLES_CLASS_1_BASE = 7000  # desired total for class 1 (real count is subtracted)
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),          # int range (low, high)
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)        # int range (low, high)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Columns dropped before CTGAN training (restored later by add_derived_features)
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Configure the runtime: pick a compute device and silence noisy warnings.

    Returns:
        torch.device: "cuda" when a GPU is available, otherwise "cpu".
    """
    has_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if has_gpu else "cpu")
    print(f"Using device: {device}")
    # Optuna emits UserWarnings from its distributions module; ignore them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a regional CSV and split it into features and target.

    Args:
        file_path: Path to the CSV file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        Tuple of (filtered full dataframe, feature frame X, target series y).
    """
    frame = pd.read_csv(file_path, index_col=0)
    # Keep only the configured training years.
    frame = frame.loc[frame['year'].isin(train_years), :]
    # Cast cloud-cover columns to int so get_categorical_feature_names()
    # picks them up as categorical (it treats non-float64 as categorical).
    for col in ('cloudcover', 'lm_cloudcover'):
        frame[col] = frame[col].astype('int')

    features = frame.drop(columns=['multi_class', 'binary_class'])
    target = frame['multi_class']

    # Derived columns are removed here and rebuilt later by add_derived_features().
    features.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return frame, features, target
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return column names whose dtype is not float64 (treated as categorical)."""
    return [name for name, dtype in df.dtypes.items() if dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective that scores CTGAN hyperparameters.

    The score is the negated squared gap between the real and generated
    'visi' distributions (mean and std), so maximizing it pulls the
    synthetic distribution toward the real one.

    Args:
        data: Training dataframe containing a 'multi_class' column.
        class_label: Class label (0 or 1) whose rows are modeled.
        categorical_features: Column names CTGAN treats as discrete.
        hp_ranges: Search space per hyperparameter.

    Returns:
        A callable usable with study.optimize().
    """
    subset = data[data['multi_class'] == class_label]

    def objective(trial):
        # Suggestion order is kept identical to preserve Optuna's sampling.
        params = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        # Build and seed the candidate model.
        model = CTGAN(**params)
        model.set_random_state(RANDOM_STATE)

        # Train on this class's rows only.
        model.fit(subset, discrete_columns=categorical_features)

        # Draw twice as many rows as the real subset.
        synthetic = model.sample(len(subset) * 2)

        # Compare the continuous 'visi' column between real and synthetic.
        real_visi = subset['visi']
        fake_visi = synthetic['visi']

        # Squared distance between the first two moments.
        gap = ((real_visi.mean() - fake_visi.mean())**2 +
               (real_visi.std() - fake_visi.std())**2)
        return -gap

    return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain with the best settings, and sample.

    Args:
        data: Training dataframe containing a 'multi_class' column.
        class_label: Class label (0 or 1) to model.
        categorical_features: Column names CTGAN treats as discrete.
        hp_ranges: Hyperparameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        Tuple of (generated sample dataframe, fitted CTGAN model).
    """
    # Build the objective for this class.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Rebuild CTGAN with the winning configuration.
    best = study.best_params
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"]
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then draw the requested sample count.
    class_rows = data[data['multi_class'] == class_label]
    model.fit(class_rows, discrete_columns=categorical_features)
    synthetic = model.sample(target_samples)

    return synthetic, model
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Re-create the derived columns that were dropped before CTGAN training.

    Args:
        df: Dataframe with 'multi_class', 'hour', 'month', 'groundtemp'
            and 'temp_C' columns.

    Returns:
        A copy of the dataframe with binary target, cyclic time encodings,
        and the ground/air temperature difference appended.
    """
    result = df.copy()
    # multi_class == 2 is the binary negative; classes 0 and 1 are positive.
    result['binary_class'] = result['multi_class'].apply(lambda x: 0 if x == 2 else 1)
    # Sine/cosine encodings keep the hour and month cycles continuous.
    result['hour_sin'] = np.sin(2 * np.pi * result['hour'] / 24)
    result['hour_cos'] = np.cos(2 * np.pi * result['hour'] / 24)
    result['month_sin'] = np.sin(2 * np.pi * result['month'] / 12)
    result['month_cos'] = np.cos(2 * np.pi * result['month'] / 12)
    result['ground_temp - temp_C'] = result['groundtemp'] - result['temp_C']
    return result
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
|
| 204 |
+
"""
|
| 205 |
+
ํน์ ์ง์ญ์ ๋ฐ์ดํฐ์ CTGAN๋ง ์ ์ฉํ์ฌ ์ฆ๊ฐ
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
file_path: ์
๋ ฅ ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
|
| 209 |
+
output_path: ์ถ๋ ฅ ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
|
| 210 |
+
model_save_dir: ๋ชจ๋ธ ์ ์ฅ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
|
| 211 |
+
"""
|
| 212 |
+
# ์ง์ญ๋ช
์ถ์ถ (ํ์ผ ๊ฒฝ๋ก์์)
|
| 213 |
+
region_name = Path(file_path).stem.replace('_train', '')
|
| 214 |
+
|
| 215 |
+
# ๋ฐ์ดํฐ ๋ก๋ ๋ฐ ์ ์ฒ๋ฆฌ
|
| 216 |
+
original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)
|
| 217 |
+
|
| 218 |
+
# ์๋ณธ ๋ฐ์ดํฐ์ multi_class ์ถ๊ฐ
|
| 219 |
+
train_data = X.copy()
|
| 220 |
+
train_data['multi_class'] = y
|
| 221 |
+
|
| 222 |
+
# CTGAN์ ์ํ ๋ฒ์ฃผํ ๋ณ์ ์ด๋ฆ ์ถ์ถ
|
| 223 |
+
categorical_features = get_categorical_feature_names(train_data)
|
| 224 |
+
|
| 225 |
+
# ํด๋์ค๋ณ ์ํ ์ ๊ณ์ฐ
|
| 226 |
+
count_class_0 = (y == 0).sum()
|
| 227 |
+
count_class_1 = (y == 1).sum()
|
| 228 |
+
target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - count_class_1
|
| 229 |
+
|
| 230 |
+
# ํด๋์ค 0์ ๋ํ CTGAN ์ต์ ํ ๋ฐ ์ํ ์์ฑ
|
| 231 |
+
print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
|
| 232 |
+
generated_0, ctgan_model_0 = optimize_and_generate_samples(
|
| 233 |
+
train_data, 0, categorical_features,
|
| 234 |
+
CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# ํด๋์ค 1์ ๋ํ CTGAN ์ต์ ํ ๋ฐ ์ํ ์์ฑ
|
| 238 |
+
print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
|
| 239 |
+
generated_1, ctgan_model_1 = optimize_and_generate_samples(
|
| 240 |
+
train_data, 1, categorical_features,
|
| 241 |
+
CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# ๋ชจ๋ธ ์ ์ฅ ๋๋ ํ ๋ฆฌ ์์ฑ
|
| 245 |
+
model_save_dir.mkdir(parents=True, exist_ok=True)
|
| 246 |
+
|
| 247 |
+
# ํด๋์ค 0 ๋ชจ๋ธ ์ ์ฅ
|
| 248 |
+
model_path_0 = model_save_dir / f'ctgan_only_7000_1_{region_name}_class0.pkl'
|
| 249 |
+
ctgan_model_0.save(str(model_path_0))
|
| 250 |
+
print(f"Saved CTGAN model for class 0: {model_path_0}")
|
| 251 |
+
|
| 252 |
+
# ํด๋์ค 1 ๋ชจ๋ธ ์ ์ฅ
|
| 253 |
+
model_path_1 = model_save_dir / f'ctgan_only_7000_1_{region_name}_class1.pkl'
|
| 254 |
+
ctgan_model_1.save(str(model_path_1))
|
| 255 |
+
print(f"Saved CTGAN model for class 1: {model_path_1}")
|
| 256 |
+
|
| 257 |
+
# ํด๋์ค๋ณ ๊ฐ์๋ ๋ฒ์๋ก ํํฐ๋ง
|
| 258 |
+
well_generated_0 = generated_0[
|
| 259 |
+
(generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
|
| 260 |
+
]
|
| 261 |
+
well_generated_1 = generated_1[
|
| 262 |
+
(generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
# ์ฆ๊ฐ๋ ๋ฐ์ดํฐ๋ง ์ ์ฅ (CTGAN์ผ๋ก ์์ฑ๋ ์ํ๋ง)
|
| 266 |
+
augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
|
| 267 |
+
augmented_only = add_derived_features(augmented_only)
|
| 268 |
+
augmented_only.reset_index(drop=True, inplace=True)
|
| 269 |
+
# augmented_only ํด๋์ ์ ์ฅ
|
| 270 |
+
output_path_obj = Path(output_path)
|
| 271 |
+
augmented_dir = output_path_obj.parent.parent / 'augmented_only'
|
| 272 |
+
augmented_dir.mkdir(parents=True, exist_ok=True)
|
| 273 |
+
augmented_output_path = augmented_dir / output_path_obj.name
|
| 274 |
+
augmented_only.to_csv(augmented_output_path, index=False)
|
| 275 |
+
|
| 276 |
+
# ์๋ณธ ๋ฐ์ดํฐ์ ํํฐ๋ง๋ CTGAN ์ํ ๋ณํฉ
|
| 277 |
+
ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)
|
| 278 |
+
|
| 279 |
+
# ํ์ ๋ณ์ ์ถ๊ฐ
|
| 280 |
+
ctgan_data = add_derived_features(ctgan_data)
|
| 281 |
+
|
| 282 |
+
# ์ฆ๊ฐ๋ ๋ฐ์ดํฐ๋ง ๊ฒฐ๊ณผ ์ถ๋ ฅ
|
| 283 |
+
aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
|
| 284 |
+
aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
|
| 285 |
+
print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")
|
| 286 |
+
|
| 287 |
+
# ํด๋์ค 2 ์ ๊ฑฐ ํ ์๋ณธ ํด๋์ค 2 ๋ฐ์ดํฐ ์ถ๊ฐ
|
| 288 |
+
filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
|
| 289 |
+
original_class_2 = original_data[original_data['multi_class'] == 2]
|
| 290 |
+
final_data = pd.concat([filtered_data, original_class_2], axis=0)
|
| 291 |
+
final_data.reset_index(drop=True, inplace=True)
|
| 292 |
+
|
| 293 |
+
# ์ถ๋ ฅ ๋๋ ํ ๋ฆฌ ์์ฑ
|
| 294 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 295 |
+
|
| 296 |
+
# ๊ฒฐ๊ณผ ์ ์ฅ
|
| 297 |
+
final_data.to_csv(output_path, index=False)
|
| 298 |
+
|
| 299 |
+
# ๊ฒฐ๊ณผ ์ถ๋ ฅ
|
| 300 |
+
count_0 = len(final_data[final_data['multi_class'] == 0])
|
| 301 |
+
count_1 = len(final_data[final_data['multi_class'] == 1])
|
| 302 |
+
count_2 = len(final_data[final_data['multi_class'] == 2])
|
| 303 |
+
print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
|
| 309 |
+
setup_environment()
|
| 310 |
+
|
| 311 |
+
file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
|
| 312 |
+
output_paths = [f'../../../data/data_oversampled/ctgan7000/ctgan7000_1_{region}.csv' for region in REGIONS]
|
| 313 |
+
model_save_dir = Path('../../save_model/oversampling_models')
|
| 314 |
+
|
| 315 |
+
for file_path, output_path in zip(file_paths, output_paths):
|
| 316 |
+
process_region(file_path, output_path, model_save_dir)
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_7000_2.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== ์์ ์ ์ ====================
|
| 11 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 12 |
+
TRAIN_YEARS = [2018, 2020]
|
| 13 |
+
TARGET_SAMPLES_CLASS_0 = 7000
|
| 14 |
+
TARGET_SAMPLES_CLASS_1_BASE = 7000
|
| 15 |
+
RANDOM_STATE = 42
|
| 16 |
+
|
| 17 |
+
# Optuna ์ต์ ํ ์ค์
|
| 18 |
+
CLASS_0_TRIALS = 50
|
| 19 |
+
CLASS_1_TRIALS = 30
|
| 20 |
+
|
| 21 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 22 |
+
CLASS_0_HP_RANGES = {
|
| 23 |
+
'embedding_dim': (64, 128),
|
| 24 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 25 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'pac': [4, 8],
|
| 27 |
+
'batch_size': [64, 128, 256],
|
| 28 |
+
'discriminator_steps': (1, 3)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
CLASS_1_HP_RANGES = {
|
| 32 |
+
'embedding_dim': (128, 512),
|
| 33 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 34 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'pac': [4, 8],
|
| 36 |
+
'batch_size': [256, 512, 1024],
|
| 37 |
+
'discriminator_steps': (1, 5)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 41 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
|
| 46 |
+
"""ํ๊ฒฝ ์ค์ (GPU, ๊ฒฝ๊ณ ๋ฌด์)"""
|
| 47 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 48 |
+
print(f"Using device: {device}")
|
| 49 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
|
| 50 |
+
return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
|
| 54 |
+
"""
|
| 55 |
+
๋ฐ์ดํฐ ๋ก๋ ๋ฐ ์ ์ฒ๋ฆฌ
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
file_path: ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
|
| 59 |
+
train_years: ํ์ต์ ์ฌ์ฉํ ์ฐ๋ ๋ฆฌ์คํธ
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
(data, X, y): ์๋ณธ ๋ฐ์ดํฐ, ํน์ง ๋ฐ์ดํฐ, ํ๊ฒ ๋ฐ์ดํฐ
|
| 63 |
+
"""
|
| 64 |
+
data = pd.read_csv(file_path, index_col=0)
|
| 65 |
+
data = data.loc[data['year'].isin(train_years), :]
|
| 66 |
+
data['cloudcover'] = data['cloudcover'].astype('int')
|
| 67 |
+
data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')
|
| 68 |
+
|
| 69 |
+
X = data.drop(columns=['multi_class', 'binary_class'])
|
| 70 |
+
y = data['multi_class']
|
| 71 |
+
|
| 72 |
+
# ๋ถํ์ํ ์ด ์ ๊ฑฐ
|
| 73 |
+
X.drop(columns=COLUMNS_TO_DROP, inplace=True)
|
| 74 |
+
|
| 75 |
+
return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
|
| 79 |
+
"""๋ฒ์ฃผํ ๋ณ์์ ์ด ์ด๋ฆ ๋ฐํ"""
|
| 80 |
+
return [col for col, dtype in zip(df.columns, df.dtypes) if dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
|
| 84 |
+
categorical_features: list,
|
| 85 |
+
hp_ranges: dict) -> callable:
|
| 86 |
+
"""
|
| 87 |
+
Optuna ์ต์ ํ๋ฅผ ์ํ ๋ชฉ์ ํจ์ ์์ฑ
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
data: ํ์ต ๋ฐ์ดํฐ
|
| 91 |
+
class_label: ํด๋์ค ๋ ์ด๋ธ (0 ๋๋ 1)
|
| 92 |
+
categorical_features: ๋ฒ์ฃผํ ๋ณ์ ์ด๋ฆ ๋ฆฌ์คํธ
|
| 93 |
+
hp_ranges: ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Optuna ๋ชฉ์ ํจ์
|
| 97 |
+
"""
|
| 98 |
+
class_data = data[data['multi_class'] == class_label]
|
| 99 |
+
|
| 100 |
+
def objective(trial):
|
| 101 |
+
# ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์ ์ค์
|
| 102 |
+
embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
|
| 103 |
+
generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
|
| 104 |
+
discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
|
| 105 |
+
pac = trial.suggest_categorical("pac", hp_ranges['pac'])
|
| 106 |
+
batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
|
| 107 |
+
discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])
|
| 108 |
+
|
| 109 |
+
# CTGAN ๋ชจ๋ธ ์์ฑ
|
| 110 |
+
ctgan = CTGAN(
|
| 111 |
+
embedding_dim=embedding_dim,
|
| 112 |
+
generator_dim=generator_dim,
|
| 113 |
+
discriminator_dim=discriminator_dim,
|
| 114 |
+
batch_size=batch_size,
|
| 115 |
+
discriminator_steps=discriminator_steps,
|
| 116 |
+
pac=pac
|
| 117 |
+
)
|
| 118 |
+
ctgan.set_random_state(RANDOM_STATE)
|
| 119 |
+
|
| 120 |
+
# ๋ชจ๋ธ ํ์ต
|
| 121 |
+
ctgan.fit(class_data, discrete_columns=categorical_features)
|
| 122 |
+
|
| 123 |
+
# ์ํ ์์ฑ
|
| 124 |
+
generated_data = ctgan.sample(len(class_data) * 2)
|
| 125 |
+
|
| 126 |
+
# ํ๊ฐ: ์ํ์ ์ฐ์ํ ๋ณ์ ๋ถํฌ ๋น๊ต
|
| 127 |
+
real_visi = class_data['visi']
|
| 128 |
+
generated_visi = generated_data['visi']
|
| 129 |
+
|
| 130 |
+
# ๋ถํฌ ๊ฐ ์ฐจ์ด(MSE) ๊ณ์ฐ
|
| 131 |
+
mse = ((real_visi.mean() - generated_visi.mean())**2 +
|
| 132 |
+
(real_visi.std() - generated_visi.std())**2)
|
| 133 |
+
return -mse
|
| 134 |
+
|
| 135 |
+
return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
|
| 139 |
+
categorical_features: list,
|
| 140 |
+
hp_ranges: dict, n_trials: int,
|
| 141 |
+
target_samples: int) -> tuple:
|
| 142 |
+
"""
|
| 143 |
+
CTGAN ์ต์ ํ ๋ฐ ์ํ ์์ฑ
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
data: ํ์ต ๋ฐ์ดํฐ
|
| 147 |
+
class_label: ํด๋์ค ๋ ์ด๋ธ (0 ๋๋ 1)
|
| 148 |
+
categorical_features: ๋ฒ์ฃผํ ๋ณ์ ์ด๋ฆ ๋ฆฌ์คํธ
|
| 149 |
+
hp_ranges: ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 150 |
+
n_trials: Optuna ์ต์ ํ ์๋ ํ์
|
| 151 |
+
target_samples: ์์ฑํ ์ํ ์
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
(์์ฑ๋ ์ํ ๋ฐ์ดํฐํ๋ ์, ํ์ต๋ CTGAN ๋ชจ๋ธ)
|
| 155 |
+
"""
|
| 156 |
+
# ๋ชฉ์ ํจ์ ์์ฑ
|
| 157 |
+
objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)
|
| 158 |
+
|
| 159 |
+
# Optuna๋ก ์ต์ ํ ์ํ
|
| 160 |
+
study = optuna.create_study(direction="maximize")
|
| 161 |
+
study.optimize(objective, n_trials=n_trials)
|
| 162 |
+
|
| 163 |
+
# ์ต์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก CTGAN ํ์ต ๋ฐ ์ํ ์์ฑ
|
| 164 |
+
best_params = study.best_params
|
| 165 |
+
ctgan = CTGAN(
|
| 166 |
+
embedding_dim=best_params["embedding_dim"],
|
| 167 |
+
generator_dim=best_params["generator_dim"],
|
| 168 |
+
discriminator_dim=best_params["discriminator_dim"],
|
| 169 |
+
batch_size=best_params["batch_size"],
|
| 170 |
+
discriminator_steps=best_params["discriminator_steps"],
|
| 171 |
+
pac=best_params["pac"]
|
| 172 |
+
)
|
| 173 |
+
ctgan.set_random_state(RANDOM_STATE)
|
| 174 |
+
|
| 175 |
+
# ์ต์ข
ํ์ต ๋ฐ ์ํ ์์ฑ
|
| 176 |
+
class_data = data[data['multi_class'] == class_label]
|
| 177 |
+
ctgan.fit(class_data, discrete_columns=categorical_features)
|
| 178 |
+
generated_samples = ctgan.sample(target_samples)
|
| 179 |
+
|
| 180 |
+
return generated_samples, ctgan
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
|
| 184 |
+
"""
|
| 185 |
+
์ ๊ฑฐํ๋ ํ์ ๋ณ์๋ค์ ๋ณต๊ตฌ
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
df: ๋ฐ์ดํฐํ๋ ์
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
ํ์ ๋ณ์๊ฐ ์ถ๊ฐ๋ ๋ฐ์ดํฐํ๋ ์
|
| 192 |
+
"""
|
| 193 |
+
df = df.copy()
|
| 194 |
+
df['binary_class'] = df['multi_class'].apply(lambda x: 0 if x == 2 else 1)
|
| 195 |
+
df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
|
| 196 |
+
df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)
|
| 197 |
+
df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
|
| 198 |
+
df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
|
| 199 |
+
df['ground_temp - temp_C'] = df['groundtemp'] - df['temp_C']
|
| 200 |
+
return df
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
|
| 204 |
+
"""
|
| 205 |
+
ํน์ ์ง์ญ์ ๋ฐ์ดํฐ์ CTGAN๋ง ์ ์ฉํ์ฌ ์ฆ๊ฐ
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
file_path: ์
๋ ฅ ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
|
| 209 |
+
output_path: ์ถ๋ ฅ ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
|
| 210 |
+
model_save_dir: ๋ชจ๋ธ ์ ์ฅ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
|
| 211 |
+
"""
|
| 212 |
+
# ์ง์ญ๋ช
์ถ์ถ (ํ์ผ ๊ฒฝ๋ก์์)
|
| 213 |
+
region_name = Path(file_path).stem.replace('_train', '')
|
| 214 |
+
|
| 215 |
+
# ๋ฐ์ดํฐ ๋ก๋ ๋ฐ ์ ์ฒ๋ฆฌ
|
| 216 |
+
original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)
|
| 217 |
+
|
| 218 |
+
# ์๋ณธ ๋ฐ์ดํฐ์ multi_class ์ถ๊ฐ
|
| 219 |
+
train_data = X.copy()
|
| 220 |
+
train_data['multi_class'] = y
|
| 221 |
+
|
| 222 |
+
# CTGAN์ ์ํ ๋ฒ์ฃผํ ๋ณ์ ์ด๋ฆ ์ถ์ถ
|
| 223 |
+
categorical_features = get_categorical_feature_names(train_data)
|
| 224 |
+
|
| 225 |
+
# ํด๋์ค๋ณ ์ํ ์ ๊ณ์ฐ
|
| 226 |
+
count_class_0 = (y == 0).sum()
|
| 227 |
+
count_class_1 = (y == 1).sum()
|
| 228 |
+
target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - count_class_1
|
| 229 |
+
|
| 230 |
+
# ํด๋์ค 0์ ๋ํ CTGAN ์ต์ ํ ๋ฐ ์ํ ์์ฑ
|
| 231 |
+
print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
|
| 232 |
+
generated_0, ctgan_model_0 = optimize_and_generate_samples(
|
| 233 |
+
train_data, 0, categorical_features,
|
| 234 |
+
CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# ํด๋์ค 1์ ๋ํ CTGAN ์ต์ ํ ๋ฐ ์ํ ์์ฑ
|
| 238 |
+
print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
|
| 239 |
+
generated_1, ctgan_model_1 = optimize_and_generate_samples(
|
| 240 |
+
train_data, 1, categorical_features,
|
| 241 |
+
CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# ๋ชจ๋ธ ์ ์ฅ ๋๋ ํ ๋ฆฌ ์์ฑ
|
| 245 |
+
model_save_dir.mkdir(parents=True, exist_ok=True)
|
| 246 |
+
|
| 247 |
+
# ํด๋์ค 0 ๋ชจ๋ธ ์ ์ฅ
|
| 248 |
+
model_path_0 = model_save_dir / f'ctgan_only_7000_2_{region_name}_class0.pkl'
|
| 249 |
+
ctgan_model_0.save(str(model_path_0))
|
| 250 |
+
print(f"Saved CTGAN model for class 0: {model_path_0}")
|
| 251 |
+
|
| 252 |
+
# ํด๋์ค 1 ๋ชจ๋ธ ์ ์ฅ
|
| 253 |
+
model_path_1 = model_save_dir / f'ctgan_only_7000_2_{region_name}_class1.pkl'
|
| 254 |
+
ctgan_model_1.save(str(model_path_1))
|
| 255 |
+
print(f"Saved CTGAN model for class 1: {model_path_1}")
|
| 256 |
+
|
| 257 |
+
# ํด๋์ค๋ณ ๊ฐ์๋ ๋ฒ์๋ก ํํฐ๋ง
|
| 258 |
+
well_generated_0 = generated_0[
|
| 259 |
+
(generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
|
| 260 |
+
]
|
| 261 |
+
well_generated_1 = generated_1[
|
| 262 |
+
(generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
# ์ฆ๊ฐ๋ ๋ฐ์ดํฐ๋ง ์ ์ฅ (CTGAN์ผ๋ก ์์ฑ๋ ์ํ๋ง)
|
| 266 |
+
augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
|
| 267 |
+
augmented_only = add_derived_features(augmented_only)
|
| 268 |
+
augmented_only.reset_index(drop=True, inplace=True)
|
| 269 |
+
# augmented_only ํด๋์ ์ ์ฅ
|
| 270 |
+
output_path_obj = Path(output_path)
|
| 271 |
+
augmented_dir = output_path_obj.parent.parent / 'augmented_only'
|
| 272 |
+
augmented_dir.mkdir(parents=True, exist_ok=True)
|
| 273 |
+
augmented_output_path = augmented_dir / output_path_obj.name
|
| 274 |
+
augmented_only.to_csv(augmented_output_path, index=False)
|
| 275 |
+
|
| 276 |
+
# ์๋ณธ ๋ฐ์ดํฐ์ ํํฐ๋ง๋ CTGAN ์ํ ๋ณํฉ
|
| 277 |
+
ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)
|
| 278 |
+
|
| 279 |
+
# ํ์ ๋ณ์ ์ถ๊ฐ
|
| 280 |
+
ctgan_data = add_derived_features(ctgan_data)
|
| 281 |
+
|
| 282 |
+
# ์ฆ๊ฐ๋ ๋ฐ์ดํฐ๋ง ๊ฒฐ๊ณผ ์ถ๋ ฅ
|
| 283 |
+
aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
|
| 284 |
+
aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
|
| 285 |
+
print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")
|
| 286 |
+
|
| 287 |
+
# ํด๋์ค 2 ์ ๊ฑฐ ํ ์๋ณธ ํด๋์ค 2 ๋ฐ์ดํฐ ์ถ๊ฐ
|
| 288 |
+
filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
|
| 289 |
+
original_class_2 = original_data[original_data['multi_class'] == 2]
|
| 290 |
+
final_data = pd.concat([filtered_data, original_class_2], axis=0)
|
| 291 |
+
final_data.reset_index(drop=True, inplace=True)
|
| 292 |
+
|
| 293 |
+
# ์ถ๋ ฅ ๋๋ ํ ๋ฆฌ ์์ฑ
|
| 294 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 295 |
+
|
| 296 |
+
# ๊ฒฐ๊ณผ ์ ์ฅ
|
| 297 |
+
final_data.to_csv(output_path, index=False)
|
| 298 |
+
|
| 299 |
+
# ๊ฒฐ๊ณผ ์ถ๋ ฅ
|
| 300 |
+
count_0 = len(final_data[final_data['multi_class'] == 0])
|
| 301 |
+
count_1 = len(final_data[final_data['multi_class'] == 1])
|
| 302 |
+
count_2 = len(final_data[final_data['multi_class'] == 2])
|
| 303 |
+
print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
|
| 309 |
+
setup_environment()
|
| 310 |
+
|
| 311 |
+
file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
|
| 312 |
+
output_paths = [f'../../../data/data_oversampled/ctgan7000/ctgan7000_2_{region}.csv' for region in REGIONS]
|
| 313 |
+
model_save_dir = Path('../../save_model/oversampling_models')
|
| 314 |
+
|
| 315 |
+
for file_path, output_path in zip(file_paths, output_paths):
|
| 316 |
+
process_region(file_path, output_path, model_save_dir)
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/only_ctgan/ctgan_sample_7000_3.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== ์์ ์ ์ ====================
|
| 11 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 12 |
+
TRAIN_YEARS = [2019, 2020]
|
| 13 |
+
TARGET_SAMPLES_CLASS_0 = 7000
|
| 14 |
+
TARGET_SAMPLES_CLASS_1_BASE = 7000
|
| 15 |
+
RANDOM_STATE = 42
|
| 16 |
+
|
| 17 |
+
# Optuna ์ต์ ํ ์ค์
|
| 18 |
+
CLASS_0_TRIALS = 50
|
| 19 |
+
CLASS_1_TRIALS = 30
|
| 20 |
+
|
| 21 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 22 |
+
CLASS_0_HP_RANGES = {
|
| 23 |
+
'embedding_dim': (64, 128),
|
| 24 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 25 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'pac': [4, 8],
|
| 27 |
+
'batch_size': [64, 128, 256],
|
| 28 |
+
'discriminator_steps': (1, 3)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
CLASS_1_HP_RANGES = {
|
| 32 |
+
'embedding_dim': (128, 512),
|
| 33 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 34 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'pac': [4, 8],
|
| 36 |
+
'batch_size': [256, 512, 1024],
|
| 37 |
+
'discriminator_steps': (1, 5)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 41 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
|
| 46 |
+
"""ํ๊ฒฝ ์ค์ (GPU, ๊ฒฝ๊ณ ๋ฌด์)"""
|
| 47 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 48 |
+
print(f"Using device: {device}")
|
| 49 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
|
| 50 |
+
return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
|
| 54 |
+
"""
|
| 55 |
+
๋ฐ์ดํฐ ๋ก๋ ๋ฐ ์ ์ฒ๋ฆฌ
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
file_path: ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
|
| 59 |
+
train_years: ํ์ต์ ์ฌ์ฉํ ์ฐ๋ ๋ฆฌ์คํธ
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
(data, X, y): ์๋ณธ ๋ฐ์ดํฐ, ํน์ง ๋ฐ์ดํฐ, ํ๊ฒ ๋ฐ์ดํฐ
|
| 63 |
+
"""
|
| 64 |
+
data = pd.read_csv(file_path, index_col=0)
|
| 65 |
+
data = data.loc[data['year'].isin(train_years), :]
|
| 66 |
+
data['cloudcover'] = data['cloudcover'].astype('int')
|
| 67 |
+
data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')
|
| 68 |
+
|
| 69 |
+
X = data.drop(columns=['multi_class', 'binary_class'])
|
| 70 |
+
y = data['multi_class']
|
| 71 |
+
|
| 72 |
+
# ๋ถํ์ํ ์ด ์ ๊ฑฐ
|
| 73 |
+
X.drop(columns=COLUMNS_TO_DROP, inplace=True)
|
| 74 |
+
|
| 75 |
+
return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
|
| 79 |
+
"""๋ฒ์ฃผํ ๋ณ์์ ์ด ์ด๋ฆ ๋ฐํ"""
|
| 80 |
+
return [col for col, dtype in zip(df.columns, df.dtypes) if dtype != 'float64']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
|
| 84 |
+
categorical_features: list,
|
| 85 |
+
hp_ranges: dict) -> callable:
|
| 86 |
+
"""
|
| 87 |
+
Optuna ์ต์ ํ๋ฅผ ์ํ ๋ชฉ์ ํจ์ ์์ฑ
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
data: ํ์ต ๋ฐ์ดํฐ
|
| 91 |
+
class_label: ํด๋์ค ๋ ์ด๋ธ (0 ๋๋ 1)
|
| 92 |
+
categorical_features: ๋ฒ์ฃผํ ๋ณ์ ์ด๋ฆ ๋ฆฌ์คํธ
|
| 93 |
+
hp_ranges: ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Optuna ๋ชฉ์ ํจ์
|
| 97 |
+
"""
|
| 98 |
+
class_data = data[data['multi_class'] == class_label]
|
| 99 |
+
|
| 100 |
+
def objective(trial):
|
| 101 |
+
# ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์ ์ค์
|
| 102 |
+
embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
|
| 103 |
+
generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
|
| 104 |
+
discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
|
| 105 |
+
pac = trial.suggest_categorical("pac", hp_ranges['pac'])
|
| 106 |
+
batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
|
| 107 |
+
discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])
|
| 108 |
+
|
| 109 |
+
# CTGAN ๋ชจ๋ธ ์์ฑ
|
| 110 |
+
ctgan = CTGAN(
|
| 111 |
+
embedding_dim=embedding_dim,
|
| 112 |
+
generator_dim=generator_dim,
|
| 113 |
+
discriminator_dim=discriminator_dim,
|
| 114 |
+
batch_size=batch_size,
|
| 115 |
+
discriminator_steps=discriminator_steps,
|
| 116 |
+
pac=pac
|
| 117 |
+
)
|
| 118 |
+
ctgan.set_random_state(RANDOM_STATE)
|
| 119 |
+
|
| 120 |
+
# ๋ชจ๋ธ ํ์ต
|
| 121 |
+
ctgan.fit(class_data, discrete_columns=categorical_features)
|
| 122 |
+
|
| 123 |
+
# ์ํ ์์ฑ
|
| 124 |
+
generated_data = ctgan.sample(len(class_data) * 2)
|
| 125 |
+
|
| 126 |
+
# ํ๊ฐ: ์ํ์ ์ฐ์ํ ๋ณ์ ๋ถํฌ ๋น๊ต
|
| 127 |
+
real_visi = class_data['visi']
|
| 128 |
+
generated_visi = generated_data['visi']
|
| 129 |
+
|
| 130 |
+
# ๋ถํฌ ๊ฐ ์ฐจ์ด(MSE) ๊ณ์ฐ
|
| 131 |
+
mse = ((real_visi.mean() - generated_visi.mean())**2 +
|
| 132 |
+
(real_visi.std() - generated_visi.std())**2)
|
| 133 |
+
return -mse
|
| 134 |
+
|
| 135 |
+
return objective
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
|
| 139 |
+
categorical_features: list,
|
| 140 |
+
hp_ranges: dict, n_trials: int,
|
| 141 |
+
target_samples: int) -> tuple:
|
| 142 |
+
"""
|
| 143 |
+
CTGAN ์ต์ ํ ๋ฐ ์ํ ์์ฑ
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
data: ํ์ต ๋ฐ์ดํฐ
|
| 147 |
+
class_label: ํด๋์ค ๋ ์ด๋ธ (0 ๋๋ 1)
|
| 148 |
+
categorical_features: ๋ฒ์ฃผํ ๋ณ์ ์ด๋ฆ ๋ฆฌ์คํธ
|
| 149 |
+
hp_ranges: ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 150 |
+
n_trials: Optuna ์ต์ ํ ์๋ ํ์
|
| 151 |
+
target_samples: ์์ฑํ ์ํ ์
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
(์์ฑ๋ ์ํ ๋ฐ์ดํฐํ๋ ์, ํ์ต๋ CTGAN ๋ชจ๋ธ)
|
| 155 |
+
"""
|
| 156 |
+
# ๋ชฉ์ ํจ์ ์์ฑ
|
| 157 |
+
objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)
|
| 158 |
+
|
| 159 |
+
# Optuna๋ก ์ต์ ํ ์ํ
|
| 160 |
+
study = optuna.create_study(direction="maximize")
|
| 161 |
+
study.optimize(objective, n_trials=n_trials)
|
| 162 |
+
|
| 163 |
+
# ์ต์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก CTGAN ํ์ต ๋ฐ ์ํ ์์ฑ
|
| 164 |
+
best_params = study.best_params
|
| 165 |
+
ctgan = CTGAN(
|
| 166 |
+
embedding_dim=best_params["embedding_dim"],
|
| 167 |
+
generator_dim=best_params["generator_dim"],
|
| 168 |
+
discriminator_dim=best_params["discriminator_dim"],
|
| 169 |
+
batch_size=best_params["batch_size"],
|
| 170 |
+
discriminator_steps=best_params["discriminator_steps"],
|
| 171 |
+
pac=best_params["pac"]
|
| 172 |
+
)
|
| 173 |
+
ctgan.set_random_state(RANDOM_STATE)
|
| 174 |
+
|
| 175 |
+
# ์ต์ข
ํ์ต ๋ฐ ์ํ ์์ฑ
|
| 176 |
+
class_data = data[data['multi_class'] == class_label]
|
| 177 |
+
ctgan.fit(class_data, discrete_columns=categorical_features)
|
| 178 |
+
generated_samples = ctgan.sample(target_samples)
|
| 179 |
+
|
| 180 |
+
return generated_samples, ctgan
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
|
| 184 |
+
"""
|
| 185 |
+
์ ๊ฑฐํ๋ ํ์ ๋ณ์๋ค์ ๋ณต๊ตฌ
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
df: ๋ฐ์ดํฐํ๋ ์
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
ํ์ ๋ณ์๊ฐ ์ถ๊ฐ๋ ๋ฐ์ดํฐํ๋ ์
|
| 192 |
+
"""
|
| 193 |
+
df = df.copy()
|
| 194 |
+
df['binary_class'] = df['multi_class'].apply(lambda x: 0 if x == 2 else 1)
|
| 195 |
+
df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
|
| 196 |
+
df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)
|
| 197 |
+
df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
|
| 198 |
+
df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
|
| 199 |
+
df['ground_temp - temp_C'] = df['groundtemp'] - df['temp_C']
|
| 200 |
+
return df
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Augment one region's training data using CTGAN only.

    Trains an Optuna-tuned CTGAN per minority class (0 and 1), saves the
    models, filters the generated rows by visibility range, then writes
    both the synthetic-only data and the merged final dataset to CSV.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Extract the region name from the file path (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Re-attach multi_class to the feature frame.
    train_data = X.copy()
    train_data['multi_class'] = y

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(train_data)

    # Per-class sample counts.
    count_class_0 = (y == 0).sum()  # NOTE(review): computed but unused; class 0 uses a fixed target
    count_class_1 = (y == 1).sum()
    # Class 1 target is the shortfall relative to the configured base count.
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - count_class_1

    # Optimize CTGAN hyperparameters and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        train_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN hyperparameters and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        train_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class 0 model.
    model_path_0 = model_save_dir / f'ctgan_only_7000_3_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class 1 model.
    model_path_1 = model_save_dir / f'ctgan_only_7000_3_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only generated rows whose visibility falls in each class's range
    # (class 0: 0 <= visi < 100, class 1: 100 <= visi < 500).
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Save the augmented rows only (CTGAN-generated samples).
    augmented_only = pd.concat([well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Write the synthetic-only data into the sibling 'augmented_only' directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge original data with the filtered CTGAN samples.
    ctgan_data = pd.concat([train_data, well_generated_0, well_generated_1], axis=0)

    # Restore derived feature columns.
    ctgan_data = add_derived_features(ctgan_data)

    # Report synthetic-only class counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 rows, then append the original class 2 data back.
    filtered_data = ctgan_data[ctgan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the merged result.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
    # Configure device selection and warning filters before any training.
    setup_environment()

    # One input/output path per region; outputs go to the ctgan7000 directory.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/ctgan7000/ctgan7000_3_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    # Process each region sequentially.
    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
| 317 |
+
|
Analysis_code/2.make_oversample_data/run_ctgan_gpu0.bash
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Run the CTGAN-only sample generation scripts sequentially on GPU 0.
# Working directory: /workspace/visibility_prediction/Analysis_code/make_oversample_data
# Abort on the first failing script so later runs don't build on bad data.
set -e

export CUDA_VISIBLE_DEVICES=0

echo "=========================================="
echo "Starting CTGAN sample generation on GPU 0"
echo "=========================================="
echo ""

# Scripts are named only_ctgan/ctgan_sample_<size>_<replicate>.py.
for size in 7000 10000 20000; do
    echo "=== Processing ${size} samples ==="
    for rep in 1 2 3; do
        script="only_ctgan/ctgan_sample_${size}_${rep}.py"
        echo "Running ${script}..."
        python "${script}"
        echo ""
    done
done

echo "=========================================="
echo "All CTGAN sample generation completed!"
echo "=========================================="
|
| 58 |
+
|
Analysis_code/2.make_oversample_data/run_ctgan_gpu1.bash
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Run the SMOTENC+CTGAN sample generation scripts sequentially on GPU 1.
# Working directory: /workspace/visibility_prediction/Analysis_code/make_oversample_data
# Abort on the first failing script so later runs don't build on bad data.
set -e

export CUDA_VISIBLE_DEVICES=1

echo "=========================================="
echo "Starting SMOTENC+CTGAN sample generation on GPU 1"
echo "=========================================="
echo ""

# Scripts are named smotenc_ctgan/smotenc_ctgan_sample_<size>_<replicate>.py.
for size in 7000 10000 20000; do
    echo "=== Processing ${size} samples ==="
    for rep in 1 2 3; do
        script="smotenc_ctgan/smotenc_ctgan_sample_${size}_${rep}.py"
        echo "Running ${script}..."
        python "${script}"
        echo ""
    done
done

echo "=========================================="
echo "All SMOTENC+CTGAN sample generation completed!"
echo "=========================================="
|
| 58 |
+
|
Analysis_code/2.make_oversample_data/smote_only/smote_sample_1.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
import numpy as np
from pathlib import Path

from imblearn.over_sampling import SMOTENC

# Per-region input/output paths (replicate 1: train years 2018 and 2019).
regions = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in regions]
output_paths = [f'../../../data/data_oversampled/smote/smote_1_{region}.csv' for region in regions]

# Process each region independently.
for file_path, output_path in zip(file_paths, output_paths):
    # Load data and keep only this replicate's training years.
    original_data = pd.read_csv(file_path, index_col=0)
    # .copy() prevents chained-assignment on a slice of original_data below
    # (SettingWithCopyWarning; the dtype casts were not guaranteed to stick).
    data = original_data.loc[original_data['year'].isin([2018, 2019]), :].copy()
    data['cloudcover'] = data['cloudcover'].astype('int')
    data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')
    X = data.drop(columns=['multi_class', 'binary_class'])
    y = data['multi_class']

    # Drop derived columns; they are recomputed after resampling.
    X.drop(columns=['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos'], inplace=True)

    # Column indices of categorical (non-float64) features for SMOTENC.
    categorical_features_indices = [i for i, dtype in enumerate(X.dtypes) if dtype != 'float64']

    # Raise classes 0 and 1 to half of class 2's count rounded up in 500s.
    count_class_2 = (y == 2).sum()
    sampling_strategy = {
        0: int(np.ceil(count_class_2 / 1000) * 500),
        1: int(np.ceil(count_class_2 / 1000) * 500),
        2: count_class_2
    }

    # Apply SMOTENC.
    smotenc = SMOTENC(categorical_features=categorical_features_indices, sampling_strategy=sampling_strategy, random_state=42)
    X_resampled, y_resampled = smotenc.fit_resample(X, y)

    # Reassemble the resampled frame with the target column.
    lerp_data = X_resampled.copy()
    lerp_data['multi_class'] = y_resampled

    # Recompute the dropped derived columns.
    lerp_data['binary_class'] = lerp_data['multi_class'].apply(lambda x: 0 if x == 2 else 1)
    lerp_data['hour_sin'] = np.sin(2 * np.pi * lerp_data['hour'] / 24)
    lerp_data['hour_cos'] = np.cos(2 * np.pi * lerp_data['hour'] / 24)
    lerp_data['month_sin'] = np.sin(2 * np.pi * lerp_data['month'] / 12)
    lerp_data['month_cos'] = np.cos(2 * np.pi * lerp_data['month'] / 12)
    lerp_data['ground_temp - temp_C'] = lerp_data['groundtemp'] - lerp_data['temp_C']

    # Keep only the synthetic rows: SMOTENC appends them after the originals.
    original_data_count = len(X)
    augmented_only = lerp_data.iloc[original_data_count:].copy()
    augmented_only = augmented_only[augmented_only['multi_class'] != 2].copy()  # exclude class 2
    augmented_only.reset_index(drop=True, inplace=True)
    # Save the synthetic rows into the sibling augmented_only directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Report synthetic-only class counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the resampled frame, then append original class 2 rows.
    filtered_data = lerp_data[lerp_data['multi_class'] != 2]
    original_class_2 = data[data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Write the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
Analysis_code/2.make_oversample_data/smote_only/smote_sample_2.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
import numpy as np
from pathlib import Path

from imblearn.over_sampling import SMOTENC

# Per-region input/output paths (replicate 2: train years 2018 and 2020).
regions = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in regions]
output_paths = [f'../../../data/data_oversampled/smote/smote_2_{region}.csv' for region in regions]

# Process each region independently.
for file_path, output_path in zip(file_paths, output_paths):
    # Load data and keep only this replicate's training years.
    original_data = pd.read_csv(file_path, index_col=0)
    # .copy() prevents chained-assignment on a slice of original_data below
    # (SettingWithCopyWarning; the dtype casts were not guaranteed to stick).
    data = original_data.loc[original_data['year'].isin([2018, 2020]), :].copy()
    data['cloudcover'] = data['cloudcover'].astype('int')
    data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')
    X = data.drop(columns=['multi_class', 'binary_class'])
    y = data['multi_class']

    # Drop derived columns; they are recomputed after resampling.
    X.drop(columns=['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos'], inplace=True)

    # Column indices of categorical (non-float64) features for SMOTENC.
    categorical_features_indices = [i for i, dtype in enumerate(X.dtypes) if dtype != 'float64']

    # Raise classes 0 and 1 to half of class 2's count rounded up in 500s.
    count_class_2 = (y == 2).sum()
    sampling_strategy = {
        0: int(np.ceil(count_class_2 / 1000) * 500),
        1: int(np.ceil(count_class_2 / 1000) * 500),
        2: count_class_2
    }

    # Apply SMOTENC.
    smotenc = SMOTENC(categorical_features=categorical_features_indices, sampling_strategy=sampling_strategy, random_state=42)
    X_resampled, y_resampled = smotenc.fit_resample(X, y)

    # Reassemble the resampled frame with the target column.
    lerp_data = X_resampled.copy()
    lerp_data['multi_class'] = y_resampled

    # Recompute the dropped derived columns.
    lerp_data['binary_class'] = lerp_data['multi_class'].apply(lambda x: 0 if x == 2 else 1)
    lerp_data['hour_sin'] = np.sin(2 * np.pi * lerp_data['hour'] / 24)
    lerp_data['hour_cos'] = np.cos(2 * np.pi * lerp_data['hour'] / 24)
    lerp_data['month_sin'] = np.sin(2 * np.pi * lerp_data['month'] / 12)
    lerp_data['month_cos'] = np.cos(2 * np.pi * lerp_data['month'] / 12)
    lerp_data['ground_temp - temp_C'] = lerp_data['groundtemp'] - lerp_data['temp_C']

    # Keep only the synthetic rows: SMOTENC appends them after the originals.
    original_data_count = len(X)
    augmented_only = lerp_data.iloc[original_data_count:].copy()
    augmented_only = augmented_only[augmented_only['multi_class'] != 2].copy()  # exclude class 2
    augmented_only.reset_index(drop=True, inplace=True)
    # Save the synthetic rows into the sibling augmented_only directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Report synthetic-only class counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the resampled frame, then append original class 2 rows.
    filtered_data = lerp_data[lerp_data['multi_class'] != 2]
    original_class_2 = data[data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Write the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
Analysis_code/2.make_oversample_data/smote_only/smote_sample_3.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
import numpy as np
from pathlib import Path

from imblearn.over_sampling import SMOTENC

# Per-region input/output paths (replicate 3: train years 2019 and 2020).
regions = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in regions]
output_paths = [f'../../../data/data_oversampled/smote/smote_3_{region}.csv' for region in regions]

# Process each region independently.
for file_path, output_path in zip(file_paths, output_paths):
    # Load data and keep only this replicate's training years.
    original_data = pd.read_csv(file_path, index_col=0)
    # .copy() prevents chained-assignment on a slice of original_data below
    # (SettingWithCopyWarning; the dtype casts were not guaranteed to stick).
    data = original_data.loc[original_data['year'].isin([2019, 2020]), :].copy()
    data['cloudcover'] = data['cloudcover'].astype('int')
    data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')
    X = data.drop(columns=['multi_class', 'binary_class'])
    y = data['multi_class']

    # Drop derived columns; they are recomputed after resampling.
    X.drop(columns=['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos'], inplace=True)

    # Column indices of categorical (non-float64) features for SMOTENC.
    categorical_features_indices = [i for i, dtype in enumerate(X.dtypes) if dtype != 'float64']

    # Raise classes 0 and 1 to half of class 2's count rounded up in 500s.
    count_class_2 = (y == 2).sum()
    sampling_strategy = {
        0: int(np.ceil(count_class_2 / 1000) * 500),
        1: int(np.ceil(count_class_2 / 1000) * 500),
        2: count_class_2
    }

    # Apply SMOTENC.
    smotenc = SMOTENC(categorical_features=categorical_features_indices, sampling_strategy=sampling_strategy, random_state=42)
    X_resampled, y_resampled = smotenc.fit_resample(X, y)

    # Reassemble the resampled frame with the target column.
    lerp_data = X_resampled.copy()
    lerp_data['multi_class'] = y_resampled

    # Recompute the dropped derived columns.
    lerp_data['binary_class'] = lerp_data['multi_class'].apply(lambda x: 0 if x == 2 else 1)
    lerp_data['hour_sin'] = np.sin(2 * np.pi * lerp_data['hour'] / 24)
    lerp_data['hour_cos'] = np.cos(2 * np.pi * lerp_data['hour'] / 24)
    lerp_data['month_sin'] = np.sin(2 * np.pi * lerp_data['month'] / 12)
    lerp_data['month_cos'] = np.cos(2 * np.pi * lerp_data['month'] / 12)
    lerp_data['ground_temp - temp_C'] = lerp_data['groundtemp'] - lerp_data['temp_C']

    # Keep only the synthetic rows: SMOTENC appends them after the originals.
    original_data_count = len(X)
    augmented_only = lerp_data.iloc[original_data_count:].copy()
    augmented_only = augmented_only[augmented_only['multi_class'] != 2].copy()  # exclude class 2
    augmented_only.reset_index(drop=True, inplace=True)
    # Save the synthetic rows into the sibling augmented_only directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Report synthetic-only class counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the resampled frame, then append original class 2 rows.
    filtered_data = lerp_data[lerp_data['multi_class'] != 2]
    original_class_2 = data[data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Write the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_10000_1.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from imblearn.over_sampling import SMOTENC
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== Constants ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2018, 2019]
TARGET_SAMPLES_CLASS_0 = 10000       # fixed number of CTGAN samples for class 0
TARGET_SAMPLES_CLASS_1_BASE = 10000  # class 1 base; actual target = base - existing class 1 rows
RANDOM_STATE = 42

# Optuna optimization settings (number of trials per class).
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# CTGAN hyperparameter search ranges per class.
# Tuples are (low, high) ranges; lists are discrete choice sets.
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before resampling and restored afterwards.
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Select the compute device and silence noisy Optuna distribution warnings."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")
    warnings.filterwarnings(
        "ignore", category=UserWarning, module="optuna.distributions"
    )
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """
    Load a region's CSV and prepare feature/target data for oversampling.

    Args:
        file_path: Path to the data file.
        train_years: Years to keep for training.

    Returns:
        (data, X, y): filtered original frame, feature frame, target series.
    """
    data = pd.read_csv(file_path, index_col=0)
    data = data.loc[data['year'].isin(train_years), :]
    # Cloud-cover columns hold discrete levels; store them as ints.
    for col in ('cloudcover', 'lm_cloudcover'):
        data[col] = data[col].astype('int')

    y = data['multi_class']
    X = data.drop(columns=['multi_class', 'binary_class'])

    # Drop derived columns; they are recomputed after augmentation.
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return the column indices of categorical (non-float64) features."""
    indices = []
    for idx, col_dtype in enumerate(X.dtypes):
        if col_dtype != 'float64':
            indices.append(idx)
    return indices
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return names of the non-float (categorical) columns."""
    return [name for name, kind in df.dtypes.items() if kind != 'float64']
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """Build the SMOTENC ``sampling_strategy`` mapping from class counts.

    Args:
        y: Target series with classes 0, 1 and 2.

    Returns:
        dict: Desired post-resampling row count per class — class 0 is
        boosted to 500 or 1000, class 1 is rounded up to the next hundred,
        class 2 keeps its current count.
    """
    n0 = (y == 0).sum()
    n1 = (y == 1).sum()
    n2 = (y == 2).sum()

    rounded_up_n1 = int(np.ceil(n1 / 100) * 100)
    return {
        0: 500 if n0 <= 500 else 1000,
        1: rounded_up_n1,
        2: n2,
    }
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample minority classes with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series.
        categorical_features_indices: Positional indices of categorical columns.
        sampling_strategy: Desired per-class sample counts.

    Returns:
        pd.DataFrame: Resampled features with a ``multi_class`` column appended.
    """
    sampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    resampled_X, resampled_y = sampler.fit_resample(X, y)

    augmented = resampled_X.copy()
    augmented['multi_class'] = resampled_y
    return augmented
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective scoring CTGAN hyper-parameters.

    The score is the negated squared gap between the real and generated
    ``visi`` mean/std, so maximizing it pushes the synthetic distribution
    toward the real one.

    Args:
        data: Training frame containing ``multi_class``.
        class_label: Class label (0 or 1) whose rows are modelled.
        categorical_features: Discrete column names passed to CTGAN.
        hp_ranges: Search range per hyper-parameter.

    Returns:
        callable: An ``objective(trial)`` function for Optuna.
    """
    subset = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one candidate configuration from the search space.
        config = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        model = CTGAN(
            embedding_dim=config['embedding_dim'],
            generator_dim=config['generator_dim'],
            discriminator_dim=config['discriminator_dim'],
            batch_size=config['batch_size'],
            discriminator_steps=config['discriminator_steps'],
            pac=config['pac'],
        )
        model.set_random_state(RANDOM_STATE)
        model.fit(subset, discrete_columns=categorical_features)

        # Twice the real row count gives a more stable moment estimate.
        synthetic = model.sample(len(subset) * 2)

        # Negated squared error of the first two moments of `visi`.
        real_visi = subset['visi']
        fake_visi = synthetic['visi']
        moment_gap = ((real_visi.mean() - fake_visi.mean()) ** 2
                      + (real_visi.std() - fake_visi.std()) ** 2)
        return -moment_gap

    return objective
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain with the best config, draw samples.

    Args:
        data: Training frame containing ``multi_class``.
        class_label: Class label (0 or 1) to model.
        categorical_features: Discrete column names passed to CTGAN.
        hp_ranges: Search range per hyper-parameter.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to draw.

    Returns:
        tuple: ``(generated samples DataFrame, fitted CTGAN model)``.
    """
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Search for the best configuration.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Rebuild a fresh model with the winning hyper-parameters.
    best = study.best_params
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"],
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on the class subset, then sample.
    subset = data[data['multi_class'] == class_label]
    model.fit(subset, discrete_columns=categorical_features)
    synthetic = model.sample(target_samples)

    return synthetic, model
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Re-create the derived columns that were dropped before oversampling.

    Adds the binary label, cyclic hour/month encodings, and the
    ground-minus-air temperature difference.

    Args:
        df: Frame with ``multi_class``, ``hour``, ``month``, ``groundtemp``
            and ``temp_C`` columns.

    Returns:
        pd.DataFrame: Copy of ``df`` with the derived columns appended.
    """
    out = df.copy()
    # Binary target: class 2 maps to 0, classes 0 and 1 map to 1.
    out['binary_class'] = out['multi_class'].apply(lambda label: 0 if label == 2 else 1)
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC and then CTGAN oversampling to one region's data.

    Loads the region's training rows, oversamples with SMOTENC, tunes and
    fits one CTGAN per minority class (0 and 1), keeps only synthetic rows
    whose visibility falls in each class's expected range, saves the trained
    models, and writes both an augmented-only CSV and the final merged CSV.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Region name from the file path (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # CTGAN rows needed for class 1: base target minus the SMOTENC count.
    # NOTE(review): goes negative if class 1 already exceeds 10000 rows —
    # confirm real counts stay below TARGET_SAMPLES_CLASS_1_BASE.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class 0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_10000_1_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class 1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_10000_1_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose visibility matches each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only the augmented rows (SMOTENC additions + CTGAN samples);
    # the first len(X) rows of smotenc_data are the original data, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC additions only

    # Merge the augmented-only rows (SMOTENC + CTGAN).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save to the sibling 'augmented_only' folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Re-create the derived feature columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 rows, then append the untouched original class 2 data.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the result.
    final_data.to_csv(output_path, index=False)

    # Report the final per-class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 366 |
+
|
| 367 |
+
if __name__ == "__main__":
    setup_environment()

    model_save_dir = Path('../../save_model/oversampling_models')

    # Process each region: read its training CSV, write the oversampled CSV.
    for region in REGIONS:
        train_csv = f'../../../data/data_for_modeling/{region}_train.csv'
        out_csv = f'../../../data/data_oversampled/smotenc_ctgan10000/smotenc_ctgan10000_1_{region}.csv'
        process_region(train_csv, out_csv, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_10000_2.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== Constants ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2018, 2020]
TARGET_SAMPLES_CLASS_0 = 10000
TARGET_SAMPLES_CLASS_1_BASE = 10000
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyper-parameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before oversampling; re-created by add_derived_features().
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime: pick a torch device and silence optuna warnings.

    Returns:
        torch.device: ``cuda`` when a GPU is visible, otherwise ``cpu``.
    """
    gpu_available = torch.cuda.is_available()
    device = torch.device("cuda") if gpu_available else torch.device("cpu")
    print(f"Using device: {device}")
    warnings.filterwarnings(
        "ignore", category=UserWarning, module="optuna.distributions"
    )
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load one region's CSV and split it into features and target.

    Args:
        file_path: Path to the input CSV (first column is the index).
        train_years: Years whose rows are kept for training.

    Returns:
        tuple: ``(data, X, y)`` — the filtered frame, the feature frame,
        and the ``multi_class`` target series.
    """
    frame = pd.read_csv(file_path, index_col=0)
    frame = frame[frame['year'].isin(train_years)]
    # Cloud-cover columns hold discrete levels, so store them as ints.
    for column in ('cloudcover', 'lm_cloudcover'):
        frame[column] = frame[column].astype('int')

    features = frame.drop(columns=['multi_class', 'binary_class'])
    target = frame['multi_class']

    # Derived columns are rebuilt after oversampling, so drop them here.
    features = features.drop(columns=COLUMNS_TO_DROP)

    return frame, features, target
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return positional indices of the non-float (categorical) columns."""
    positions = []
    for position, column_dtype in enumerate(X.dtypes):
        if column_dtype != 'float64':
            positions.append(position)
    return positions
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return names of the non-float (categorical) columns."""
    return [name for name, kind in df.dtypes.items() if kind != 'float64']
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """Build the SMOTENC ``sampling_strategy`` mapping from class counts.

    Args:
        y: Target series with classes 0, 1 and 2.

    Returns:
        dict: Desired post-resampling row count per class — class 0 is
        boosted to 500 or 1000, class 1 is rounded up to the next hundred,
        class 2 keeps its current count.
    """
    n0 = (y == 0).sum()
    n1 = (y == 1).sum()
    n2 = (y == 2).sum()

    rounded_up_n1 = int(np.ceil(n1 / 100) * 100)
    return {
        0: 500 if n0 <= 500 else 1000,
        1: rounded_up_n1,
        2: n2,
    }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample minority classes with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series.
        categorical_features_indices: Positional indices of categorical columns.
        sampling_strategy: Desired per-class sample counts.

    Returns:
        pd.DataFrame: Resampled features with a ``multi_class`` column appended.
    """
    sampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    resampled_X, resampled_y = sampler.fit_resample(X, y)

    augmented = resampled_X.copy()
    augmented['multi_class'] = resampled_y
    return augmented
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective scoring CTGAN hyper-parameters.

    The score is the negated squared gap between the real and generated
    ``visi`` mean/std, so maximizing it pushes the synthetic distribution
    toward the real one.

    Args:
        data: Training frame containing ``multi_class``.
        class_label: Class label (0 or 1) whose rows are modelled.
        categorical_features: Discrete column names passed to CTGAN.
        hp_ranges: Search range per hyper-parameter.

    Returns:
        callable: An ``objective(trial)`` function for Optuna.
    """
    subset = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one candidate configuration from the search space.
        config = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        model = CTGAN(
            embedding_dim=config['embedding_dim'],
            generator_dim=config['generator_dim'],
            discriminator_dim=config['discriminator_dim'],
            batch_size=config['batch_size'],
            discriminator_steps=config['discriminator_steps'],
            pac=config['pac'],
        )
        model.set_random_state(RANDOM_STATE)
        model.fit(subset, discrete_columns=categorical_features)

        # Twice the real row count gives a more stable moment estimate.
        synthetic = model.sample(len(subset) * 2)

        # Negated squared error of the first two moments of `visi`.
        real_visi = subset['visi']
        fake_visi = synthetic['visi']
        moment_gap = ((real_visi.mean() - fake_visi.mean()) ** 2
                      + (real_visi.std() - fake_visi.std()) ** 2)
        return -moment_gap

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain with the best config, draw samples.

    Args:
        data: Training frame containing ``multi_class``.
        class_label: Class label (0 or 1) to model.
        categorical_features: Discrete column names passed to CTGAN.
        hp_ranges: Search range per hyper-parameter.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to draw.

    Returns:
        tuple: ``(generated samples DataFrame, fitted CTGAN model)``.
    """
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Search for the best configuration.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Rebuild a fresh model with the winning hyper-parameters.
    best = study.best_params
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"],
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on the class subset, then sample.
    subset = data[data['multi_class'] == class_label]
    model.fit(subset, discrete_columns=categorical_features)
    synthetic = model.sample(target_samples)

    return synthetic, model
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Re-create the derived columns that were dropped before oversampling.

    Adds the binary label, cyclic hour/month encodings, and the
    ground-minus-air temperature difference.

    Args:
        df: Frame with ``multi_class``, ``hour``, ``month``, ``groundtemp``
            and ``temp_C`` columns.

    Returns:
        pd.DataFrame: Copy of ``df`` with the derived columns appended.
    """
    out = df.copy()
    # Binary target: class 2 maps to 0, classes 0 and 1 map to 1.
    out['binary_class'] = out['multi_class'].apply(lambda label: 0 if label == 2 else 1)
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC and then CTGAN oversampling to one region's data.

    Loads the region's training rows, oversamples with SMOTENC, tunes and
    fits one CTGAN per minority class (0 and 1), keeps only synthetic rows
    whose visibility falls in each class's expected range, saves the trained
    models, and writes both an augmented-only CSV and the final merged CSV.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Region name from the file path (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # CTGAN rows needed for class 1: base target minus the SMOTENC count.
    # NOTE(review): goes negative if class 1 already exceeds 10000 rows —
    # confirm real counts stay below TARGET_SAMPLES_CLASS_1_BASE.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class 0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_10000_2_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class 1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_10000_2_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose visibility matches each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only the augmented rows (SMOTENC additions + CTGAN samples);
    # the first len(X) rows of smotenc_data are the original data, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC additions only

    # Merge the augmented-only rows (SMOTENC + CTGAN).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save to the sibling 'augmented_only' folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Re-create the derived feature columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 rows, then append the untouched original class 2 data.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the result.
    final_data.to_csv(output_path, index=False)

    # Report the final per-class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    setup_environment()

    model_save_dir = Path('../../save_model/oversampling_models')

    # Process each region: read its training CSV, write the oversampled CSV.
    for region in REGIONS:
        train_csv = f'../../../data/data_for_modeling/{region}_train.csv'
        out_csv = f'../../../data/data_oversampled/smotenc_ctgan10000/smotenc_ctgan10000_2_{region}.csv'
        process_region(train_csv, out_csv, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_10000_3.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== ์์ ์ ์ ====================
|
| 12 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 13 |
+
TRAIN_YEARS = [2019, 2020]
|
| 14 |
+
TARGET_SAMPLES_CLASS_0 = 10000
|
| 15 |
+
TARGET_SAMPLES_CLASS_1_BASE = 10000
|
| 16 |
+
RANDOM_STATE = 42
|
| 17 |
+
|
| 18 |
+
# Optuna ์ต์ ํ ์ค์
|
| 19 |
+
CLASS_0_TRIALS = 50
|
| 20 |
+
CLASS_1_TRIALS = 30
|
| 21 |
+
|
| 22 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 23 |
+
CLASS_0_HP_RANGES = {
|
| 24 |
+
'embedding_dim': (64, 128),
|
| 25 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 27 |
+
'pac': [4, 8],
|
| 28 |
+
'batch_size': [64, 128, 256],
|
| 29 |
+
'discriminator_steps': (1, 3)
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
CLASS_1_HP_RANGES = {
|
| 33 |
+
'embedding_dim': (128, 512),
|
| 34 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 36 |
+
'pac': [4, 8],
|
| 37 |
+
'batch_size': [256, 512, 1024],
|
| 38 |
+
'discriminator_steps': (1, 5)
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 42 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime: pick the compute device and silence noisy warnings.

    Returns:
        torch.device: "cuda" when a GPU is available, otherwise "cpu".
    """
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna emits harmless UserWarnings about distribution definitions; mute them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region's training CSV and split it into features and target.

    Args:
        file_path: Path to the CSV file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        tuple: (data, X, y) — the year-filtered frame, the feature frame with
        COLUMNS_TO_DROP removed, and the 'multi_class' target series.
    """
    data = pd.read_csv(file_path, index_col=0)
    # Keep only the requested training years.
    year_mask = data['year'].isin(train_years)
    data = data.loc[year_mask, :]
    # Cloud-cover columns are discrete levels; store them as integers so they
    # are treated as categorical downstream.
    for cloud_col in ('cloudcover', 'lm_cloudcover'):
        data[cloud_col] = data[cloud_col].astype('int')

    y = data['multi_class']
    X = data.drop(columns=['multi_class', 'binary_class'])
    # Derived columns are regenerated later by add_derived_features().
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return the positional indices of non-float64 (categorical) columns."""
    indices = []
    for position, column_dtype in enumerate(X.dtypes):
        if column_dtype != 'float64':
            indices.append(position)
    return indices
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the names of non-float64 (categorical) columns."""
    names = []
    for column_name, column_dtype in zip(df.columns, df.dtypes):
        if column_dtype != 'float64':
            names.append(column_name)
    return names
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """Compute the per-class sampling_strategy dict for SMOTENC.

    Class 0 (the rarest) is boosted to 500 (or 1000 when it already exceeds
    500 rows); class 1 is rounded up to the next hundred; class 2 (majority)
    is kept at its current count.

    Args:
        y: Target series with labels {0, 1, 2}.

    Returns:
        dict: Mapping class label -> target sample count. Every target is
        clamped to be >= the class's current count, because an oversampler
        like SMOTENC raises a ValueError when asked to shrink a class.
    """
    count_class_0 = int((y == 0).sum())
    count_class_1 = int((y == 1).sum())
    count_class_2 = int((y == 2).sum())

    # Nominal targets from the original design.
    target_class_0 = 500 if count_class_0 <= 500 else 1000
    target_class_1 = int(np.ceil(count_class_1 / 100) * 100)  # round up to hundreds

    return {
        # max() guards against a class that is already larger than its nominal
        # target (SMOTENC cannot undersample).
        0: max(count_class_0, target_class_0),
        1: max(count_class_1, target_class_1),
        2: count_class_2,
    }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample the minority classes with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series ('multi_class' labels).
        categorical_features_indices: Positional indices of categorical columns.
        sampling_strategy: Per-class target counts
            (see calculate_sampling_strategy).

    Returns:
        pd.DataFrame: Resampled features with a 'multi_class' column appended.
    """
    oversampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    features_resampled, labels_resampled = oversampler.fit_resample(X, y)

    resampled_data = features_resampled.copy()
    resampled_data['multi_class'] = labels_resampled
    return resampled_data
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective that scores CTGAN hyper-parameters.

    The objective trains a CTGAN on the rows of `data` whose 'multi_class'
    equals `class_label`, samples twice that many synthetic rows, and scores
    the configuration by how closely the synthetic 'visi' mean/std match the
    real ones (negated squared error, so larger is better when the study uses
    direction="maximize").

    Args:
        data: Training frame containing a 'multi_class' column.
        class_label: Class (0 or 1) whose rows the CTGAN should model.
        categorical_features: Names of the discrete columns for CTGAN.
        hp_ranges: Search space; see CLASS_0_HP_RANGES / CLASS_1_HP_RANGES.

    Returns:
        callable: An objective(trial) function for Optuna.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one hyper-parameter configuration from the search space.
        params = {
            'embedding_dim': trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            'generator_dim': trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            'discriminator_dim': trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            'pac': trial.suggest_categorical("pac", hp_ranges['pac']),
            'batch_size': trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            'discriminator_steps': trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        }

        ctgan = CTGAN(**params)
        ctgan.set_random_state(RANDOM_STATE)

        # Fit on this class only, then draw twice as many synthetic rows.
        ctgan.fit(class_data, discrete_columns=categorical_features)
        generated_data = ctgan.sample(len(class_data) * 2)

        # Score: squared gap between real and synthetic 'visi' mean and std,
        # negated so Optuna can maximize it.
        mean_gap = class_data['visi'].mean() - generated_data['visi'].mean()
        std_gap = class_data['visi'].std() - generated_data['visi'].std()
        return -(mean_gap ** 2 + std_gap ** 2)

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """
    Tune CTGAN hyper-parameters with Optuna, then train a final model and
    draw synthetic samples for one class.

    Args:
        data: Training data (must contain a 'multi_class' column).
        class_label: Class label to model (0 or 1).
        categorical_features: Names of the discrete columns for CTGAN.
        hp_ranges: Hyper-parameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated samples DataFrame, fitted CTGAN model)
    """
    # Build the objective for this class.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search (objective returns -MSE, hence maximize).
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Re-create the CTGAN with the best hyper-parameters found.
    best_params = study.best_params
    ctgan = CTGAN(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=best_params["generator_dim"],
        discriminator_dim=best_params["discriminator_dim"],
        batch_size=best_params["batch_size"],
        discriminator_steps=best_params["discriminator_steps"],
        pac=best_params["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then sample the requested amount.
    class_data = data[data['multi_class'] == class_label]
    ctgan.fit(class_data, discrete_columns=categorical_features)
    generated_samples = ctgan.sample(target_samples)

    return generated_samples, ctgan
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Recreate the derived columns that were dropped before oversampling.

    Adds 'binary_class' (0 for class 2, else 1), cyclic hour/month encodings,
    and the ground-air temperature difference.

    Args:
        df: Frame with 'multi_class', 'hour', 'month', 'groundtemp', 'temp_C'.

    Returns:
        pd.DataFrame: A copy of `df` with the derived columns added.
    """
    out = df.copy()
    # Binary flag: class 2 maps to 0, classes 0/1 map to 1.
    out['binary_class'] = out['multi_class'].apply(lambda label: 0 if label == 2 else 1)
    # Cyclic (sin/cos) encodings of hour-of-day and month-of-year.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    # Ground-air temperature gap (column name kept for downstream compatibility).
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC followed by CTGAN oversampling to one region's data.

    Pipeline: load/filter the training CSV -> SMOTENC up-sampling ->
    per-class CTGAN tuning and generation -> save the fitted models, the
    augmented-only rows, and the merged final dataset.

    Args:
        file_path: Input data file path (e.g. '<region>_train.csv').
        output_path: Output CSV path for the merged result.
        model_save_dir: Directory in which the fitted CTGAN models are saved.
    """
    # Region name from the file path (e.g. 'seoul_train.csv' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # SMOTENC oversampling.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Discrete column names for CTGAN.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Remaining class-1 samples needed to reach the overall target.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Tune CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Tune CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Ensure the model directory exists.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_10000_3_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_10000_3_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only samples whose visibility falls inside each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only the augmented rows (SMOTENC synthetics + CTGAN synthetics).
    # The first len(X) rows of smotenc_data are the originals, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC synthetics only

    # Merge only the augmented data (SMOTENC + CTGAN).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save into the sibling 'augmented_only' folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge the SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the merged data, then append the original class-2 rows.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ==================== Main ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output path pair per region; all models go to one directory.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan10000/smotenc_ctgan10000_3_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_20000_1.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from imblearn.over_sampling import SMOTENC
|
| 5 |
+
import optuna
|
| 6 |
+
from ctgan import CTGAN
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
# ==================== ์์ ์ ์ ====================
|
| 11 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 12 |
+
TRAIN_YEARS = [2018, 2019]
|
| 13 |
+
TARGET_SAMPLES_CLASS_0 = 20000
|
| 14 |
+
TARGET_SAMPLES_CLASS_1_BASE = 20000
|
| 15 |
+
RANDOM_STATE = 42
|
| 16 |
+
|
| 17 |
+
# Optuna ์ต์ ํ ์ค์
|
| 18 |
+
CLASS_0_TRIALS = 50
|
| 19 |
+
CLASS_1_TRIALS = 30
|
| 20 |
+
|
| 21 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 22 |
+
CLASS_0_HP_RANGES = {
|
| 23 |
+
'embedding_dim': (64, 128),
|
| 24 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 25 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'pac': [4, 8],
|
| 27 |
+
'batch_size': [64, 128, 256],
|
| 28 |
+
'discriminator_steps': (1, 3)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
CLASS_1_HP_RANGES = {
|
| 32 |
+
'embedding_dim': (128, 512),
|
| 33 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 34 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'pac': [4, 8],
|
| 36 |
+
'batch_size': [256, 512, 1024],
|
| 37 |
+
'discriminator_steps': (1, 5)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 41 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 42 |
+
|
| 43 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 44 |
+
|
| 45 |
+
def setup_environment():
    """Configure the runtime: pick the compute device and silence noisy warnings.

    Returns:
        torch.device: "cuda" when a GPU is available, otherwise "cpu".
    """
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna emits harmless UserWarnings about distribution definitions; mute them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region's training CSV and split it into features and target.

    Args:
        file_path: Path to the CSV file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        tuple: (data, X, y) — the year-filtered frame, the feature frame with
        COLUMNS_TO_DROP removed, and the 'multi_class' target series.
    """
    data = pd.read_csv(file_path, index_col=0)
    # Keep only the requested training years.
    year_mask = data['year'].isin(train_years)
    data = data.loc[year_mask, :]
    # Cloud-cover columns are discrete levels; store them as integers so they
    # are treated as categorical downstream.
    for cloud_col in ('cloudcover', 'lm_cloudcover'):
        data[cloud_col] = data[cloud_col].astype('int')

    y = data['multi_class']
    X = data.drop(columns=['multi_class', 'binary_class'])
    # Derived columns are regenerated later by add_derived_features().
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return the positional indices of non-float64 (categorical) columns."""
    indices = []
    for position, column_dtype in enumerate(X.dtypes):
        if column_dtype != 'float64':
            indices.append(position)
    return indices
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the names of non-float64 (categorical) columns."""
    names = []
    for column_name, column_dtype in zip(df.columns, df.dtypes):
        if column_dtype != 'float64':
            names.append(column_name)
    return names
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """Compute the per-class sampling_strategy dict for SMOTENC.

    Class 0 (the rarest) is boosted to 500 (or 1000 when it already exceeds
    500 rows); class 1 is rounded up to the next hundred; class 2 (majority)
    is kept at its current count.

    Args:
        y: Target series with labels {0, 1, 2}.

    Returns:
        dict: Mapping class label -> target sample count. Every target is
        clamped to be >= the class's current count, because an oversampler
        like SMOTENC raises a ValueError when asked to shrink a class.
    """
    count_class_0 = int((y == 0).sum())
    count_class_1 = int((y == 1).sum())
    count_class_2 = int((y == 2).sum())

    # Nominal targets from the original design.
    target_class_0 = 500 if count_class_0 <= 500 else 1000
    target_class_1 = int(np.ceil(count_class_1 / 100) * 100)  # round up to hundreds

    return {
        # max() guards against a class that is already larger than its nominal
        # target (SMOTENC cannot undersample).
        0: max(count_class_0, target_class_0),
        1: max(count_class_1, target_class_1),
        2: count_class_2,
    }
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample the minority classes with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series ('multi_class' labels).
        categorical_features_indices: Positional indices of categorical columns.
        sampling_strategy: Per-class target counts
            (see calculate_sampling_strategy).

    Returns:
        pd.DataFrame: Resampled features with a 'multi_class' column appended.
    """
    oversampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    features_resampled, labels_resampled = oversampler.fit_resample(X, y)

    resampled_data = features_resampled.copy()
    resampled_data['multi_class'] = labels_resampled
    return resampled_data
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build an Optuna objective function for the CTGAN hyper-parameter search.

    The objective trains a CTGAN on the rows of `data` whose 'multi_class'
    equals `class_label`, samples twice that many synthetic rows, and returns
    the negated squared gap between the real and synthetic 'visi' mean/std
    (larger is better when the study uses direction="maximize").

    Args:
        data: Training data containing a 'multi_class' column.
        class_label: Class label to model (0 or 1).
        categorical_features: Names of the discrete columns for CTGAN.
        hp_ranges: Hyper-parameter search space.

    Returns:
        An objective(trial) callable for Optuna.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one hyper-parameter configuration from the search space.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train the model on this class's rows only.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate twice as many synthetic rows for evaluation.
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluation: compare the continuous 'visi' distribution.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared gap between the distributions (mean and std), negated.
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """
    Tune CTGAN hyper-parameters with Optuna, then train a final model and
    draw synthetic samples for one class.

    Args:
        data: Training data (must contain a 'multi_class' column).
        class_label: Class label to model (0 or 1).
        categorical_features: Names of the discrete columns for CTGAN.
        hp_ranges: Hyper-parameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated samples DataFrame, fitted CTGAN model)
    """
    # Build the objective for this class.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search (objective returns -MSE, hence maximize).
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Re-create the CTGAN with the best hyper-parameters found.
    best_params = study.best_params
    ctgan = CTGAN(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=best_params["generator_dim"],
        discriminator_dim=best_params["discriminator_dim"],
        batch_size=best_params["batch_size"],
        discriminator_steps=best_params["discriminator_steps"],
        pac=best_params["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then sample the requested amount.
    class_data = data[data['multi_class'] == class_label]
    ctgan.fit(class_data, discrete_columns=categorical_features)
    generated_samples = ctgan.sample(target_samples)

    return generated_samples, ctgan
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Recreate the derived columns that were dropped before oversampling.

    Adds 'binary_class' (0 for class 2, else 1), cyclic hour/month encodings,
    and the ground-air temperature difference.

    Args:
        df: Frame with 'multi_class', 'hour', 'month', 'groundtemp', 'temp_C'.

    Returns:
        pd.DataFrame: A copy of `df` with the derived columns added.
    """
    out = df.copy()
    # Binary flag: class 2 maps to 0, classes 0/1 map to 1.
    out['binary_class'] = out['multi_class'].apply(lambda label: 0 if label == 2 else 1)
    # Cyclic (sin/cos) encodings of hour-of-day and month-of-year.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    # Ground-air temperature gap (column name kept for downstream compatibility).
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC followed by CTGAN oversampling to one region's data.

    Pipeline: load/filter the training CSV -> SMOTENC up-sampling ->
    per-class CTGAN tuning and generation -> save the fitted models, the
    augmented-only rows, and the merged final dataset.

    Args:
        file_path: Input data file path (e.g. '<region>_train.csv').
        output_path: Output CSV path for the merged result.
        model_save_dir: Directory in which the fitted CTGAN models are saved.
    """
    # Region name from the file path (e.g. 'seoul_train.csv' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # SMOTENC oversampling.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Discrete column names for CTGAN.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Remaining class-1 samples needed to reach the overall target.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Tune CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Tune CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Ensure the model directory exists.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_20000_1_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_20000_1_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only samples whose visibility falls inside each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only the augmented rows (SMOTENC synthetics + CTGAN synthetics).
    # The first len(X) rows of smotenc_data are the originals, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC synthetics only

    # Merge only the augmented data (SMOTENC + CTGAN).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save into the sibling 'augmented_only' folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge the SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the merged data, then append the original class-2 rows.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ==================== Main ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output path pair per region; all models go to one directory.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan20000/smotenc_ctgan20000_1_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_20000_2.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== ์์ ์ ์ ====================
|
| 12 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 13 |
+
TRAIN_YEARS = [2018, 2020]
|
| 14 |
+
TARGET_SAMPLES_CLASS_0 = 20000
|
| 15 |
+
TARGET_SAMPLES_CLASS_1_BASE = 20000
|
| 16 |
+
RANDOM_STATE = 42
|
| 17 |
+
|
| 18 |
+
# Optuna ์ต์ ํ ์ค์
|
| 19 |
+
CLASS_0_TRIALS = 50
|
| 20 |
+
CLASS_1_TRIALS = 30
|
| 21 |
+
|
| 22 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 23 |
+
CLASS_0_HP_RANGES = {
|
| 24 |
+
'embedding_dim': (64, 128),
|
| 25 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 27 |
+
'pac': [4, 8],
|
| 28 |
+
'batch_size': [64, 128, 256],
|
| 29 |
+
'discriminator_steps': (1, 3)
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
CLASS_1_HP_RANGES = {
|
| 33 |
+
'embedding_dim': (128, 512),
|
| 34 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 36 |
+
'pac': [4, 8],
|
| 37 |
+
'batch_size': [256, 512, 1024],
|
| 38 |
+
'discriminator_steps': (1, 5)
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 42 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime environment (device selection, warning filters).

    Returns:
        torch.device: CUDA device when available, otherwise CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna emits noisy UserWarnings about distributions; silence them.
    warnings.filterwarnings(
        "ignore", category=UserWarning, module="optuna.distributions"
    )
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region's training CSV and split it into features and target.

    Rows are restricted to ``train_years``; cloud-cover columns are cast
    to int (they are categorical levels); the target columns and the
    module-level ``COLUMNS_TO_DROP`` are removed from the feature frame.

    Args:
        file_path: Path to the input CSV file.
        train_years: Years to keep for training.

    Returns:
        (data, X, y): filtered raw frame, feature frame, target series.
    """
    frame = pd.read_csv(file_path, index_col=0)
    frame = frame[frame['year'].isin(train_years)]
    # Cloud-cover values are discrete okta-style levels, not floats.
    for column in ('cloudcover', 'lm_cloudcover'):
        frame[column] = frame[column].astype('int')

    features = frame.drop(columns=['multi_class', 'binary_class'] + COLUMNS_TO_DROP)
    target = frame['multi_class']

    return frame, features, target
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return positional indices of the non-float (categorical) columns."""
    indices = []
    for position, dtype in enumerate(X.dtypes):
        if dtype != 'float64':
            indices.append(position)
    return indices
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return names of the non-float (categorical) columns."""
    return [name for name, kind in df.dtypes.items() if kind != 'float64']
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """Build the SMOTENC ``sampling_strategy`` mapping from class counts.

    Class 0 is oversampled to 500 (when very rare) or 1000 samples,
    class 1 is rounded up to the next hundred, and class 2 (the
    majority) is left at its original count.

    Fix: SMOTE-family samplers raise ``ValueError`` if a requested
    target is below the class's existing count, so each target is
    clamped with ``max(..., current count)``.

    Args:
        y: Target series with labels {0, 1, 2}.

    Returns:
        Mapping of class label -> desired sample count after SMOTENC.
    """
    count_class_0 = int((y == 0).sum())
    count_class_1 = int((y == 1).sum())
    count_class_2 = int((y == 2).sum())

    target_class_0 = 500 if count_class_0 <= 500 else 1000
    return {
        # Never request fewer samples than already exist.
        0: max(target_class_0, count_class_0),
        1: int(np.ceil(count_class_1 / 100) * 100),  # round up to nearest 100
        2: count_class_2,
    }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample the data with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series.
        categorical_features_indices: Positions of categorical columns.
        sampling_strategy: Class -> target-count mapping.

    Returns:
        Resampled features with a ``multi_class`` column appended.
    """
    sampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    features_resampled, labels_resampled = sampler.fit_resample(X, y)

    combined = features_resampled.copy()
    combined['multi_class'] = labels_resampled
    return combined
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective that scores CTGAN hyperparameters.

    Each trial trains a CTGAN on the rows of ``data`` belonging to
    ``class_label``, samples twice as many synthetic rows, and scores
    the trial by how closely the synthetic ``visi`` column matches the
    real one in mean and standard deviation (negated so Optuna can
    maximize it).

    Args:
        data: Training data containing a ``multi_class`` column.
        class_label: Class to model (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Hyperparameter search ranges.

    Returns:
        An Optuna objective function (closure over the class slice).
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one hyperparameter configuration from the search space.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model for this trial.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train on the single-class slice only.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate synthetic samples (2x the real class count).
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluate: compare the continuous 'visi' distribution of the
        # real vs. synthetic data.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared differences of mean and std; negated because the study
        # maximizes (smaller mismatch => larger score).
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain with the best params, and sample.

    Args:
        data: Training data containing a ``multi_class`` column.
        class_label: Class to model (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Hyperparameter search ranges.
        n_trials: Number of Optuna trials to run.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated sample DataFrame, fitted CTGAN model)
    """
    # Build the per-class objective.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search; the objective returns -mse, so maximize.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Rebuild CTGAN with the best hyperparameters found.
    best_params = study.best_params
    ctgan = CTGAN(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=best_params["generator_dim"],
        discriminator_dim=best_params["discriminator_dim"],
        batch_size=best_params["batch_size"],
        discriminator_steps=best_params["discriminator_steps"],
        pac=best_params["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on the single-class slice, then sample.
    class_data = data[data['multi_class'] == class_label]
    ctgan.fit(class_data, discrete_columns=categorical_features)
    generated_samples = ctgan.sample(target_samples)

    return generated_samples, ctgan
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Recreate the derived columns that were dropped before oversampling.

    Adds the binary target (class 2 -> 0, others -> 1), cyclic hour and
    month encodings, and the ground-air temperature difference.

    Args:
        df: Input frame with 'multi_class', 'hour', 'month',
            'groundtemp' and 'temp_C' columns.

    Returns:
        A copy of ``df`` with the derived columns appended.
    """
    out = df.copy()
    out['binary_class'] = (out['multi_class'] != 2).astype(int)
    hour_angle = 2 * np.pi * out['hour'] / 24
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    month_angle = 2 * np.pi * out['month'] / 12
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """Apply SMOTENC followed by CTGAN oversampling to one region.

    Pipeline: load and filter the region's training data, oversample
    minority classes with SMOTENC, tune and fit one CTGAN per minority
    class (0 and 1), keep only synthetic rows whose 'visi' falls in the
    class's expected range, write the augmented-only rows to a sibling
    'augmented_only' directory, and finally write the merged dataset
    (SMOTENC + CTGAN + original class 2) to ``output_path``, saving the
    fitted CTGAN models along the way.

    Args:
        file_path: Input data file path (``<region>_train.csv``).
        output_path: Output CSV path for the merged dataset.
        model_save_dir: Directory in which to save fitted CTGAN models.
    """
    # Extract the region name from the file path.
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Class-1 CTGAN target = base target minus what SMOTENC already
    # produced (class 1 was rounded up to the nearest hundred).
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model output directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_20000_2_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_20000_2_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only synthetic rows whose 'visi' lies in each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only augmented rows (SMOTENC additions + CTGAN samples).
    # NOTE(review): assumes imblearn appends synthetic rows after the
    # first len(X) original rows — true for current imblearn versions,
    # but verify if the library is upgraded.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC-added rows only

    # Merge augmented-only data (SMOTENC additions + CTGAN samples).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save the augmented-only rows to a sibling 'augmented_only' folder.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived feature columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only class counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class-2 rows, then add back the original class-2 data.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final result.
    final_data.to_csv(output_path, index=False)

    # Report the final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    # Select the compute device and silence Optuna warnings.
    setup_environment()

    # One train CSV in, one oversampled CSV out, per region.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan20000/smotenc_ctgan20000_2_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    # Run the SMOTENC + CTGAN pipeline region by region.
    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_20000_3.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== ์์ ์ ์ ====================
|
| 12 |
+
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
|
| 13 |
+
TRAIN_YEARS = [2019, 2020]
|
| 14 |
+
TARGET_SAMPLES_CLASS_0 = 20000
|
| 15 |
+
TARGET_SAMPLES_CLASS_1_BASE = 20000
|
| 16 |
+
RANDOM_STATE = 42
|
| 17 |
+
|
| 18 |
+
# Optuna ์ต์ ํ ์ค์
|
| 19 |
+
CLASS_0_TRIALS = 50
|
| 20 |
+
CLASS_1_TRIALS = 30
|
| 21 |
+
|
| 22 |
+
# ํด๋์ค๋ณ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ๋ฒ์
|
| 23 |
+
CLASS_0_HP_RANGES = {
|
| 24 |
+
'embedding_dim': (64, 128),
|
| 25 |
+
'generator_dim': [(64, 64), (128, 128)],
|
| 26 |
+
'discriminator_dim': [(64, 64), (128, 128)],
|
| 27 |
+
'pac': [4, 8],
|
| 28 |
+
'batch_size': [64, 128, 256],
|
| 29 |
+
'discriminator_steps': (1, 3)
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
CLASS_1_HP_RANGES = {
|
| 33 |
+
'embedding_dim': (128, 512),
|
| 34 |
+
'generator_dim': [(128, 128), (256, 256)],
|
| 35 |
+
'discriminator_dim': [(128, 128), (256, 256)],
|
| 36 |
+
'pac': [4, 8],
|
| 37 |
+
'batch_size': [256, 512, 1024],
|
| 38 |
+
'discriminator_steps': (1, 5)
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# ์ ๊ฑฐํ ์ด ๋ชฉ๋ก
|
| 42 |
+
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime environment (device selection, warning filters).

    Returns:
        torch.device: CUDA device when available, otherwise CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna emits noisy UserWarnings about distributions; silence them.
    warnings.filterwarnings(
        "ignore", category=UserWarning, module="optuna.distributions"
    )
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """Load a region's training CSV and split it into features and target.

    Rows are restricted to ``train_years``; cloud-cover columns are cast
    to int (they are categorical levels); the target columns and the
    module-level ``COLUMNS_TO_DROP`` are removed from the feature frame.

    Args:
        file_path: Path to the input CSV file.
        train_years: Years to keep for training.

    Returns:
        (data, X, y): filtered raw frame, feature frame, target series.
    """
    frame = pd.read_csv(file_path, index_col=0)
    frame = frame[frame['year'].isin(train_years)]
    # Cloud-cover values are discrete okta-style levels, not floats.
    for column in ('cloudcover', 'lm_cloudcover'):
        frame[column] = frame[column].astype('int')

    features = frame.drop(columns=['multi_class', 'binary_class'] + COLUMNS_TO_DROP)
    target = frame['multi_class']

    return frame, features, target
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return positional indices of the non-float (categorical) columns."""
    indices = []
    for position, dtype in enumerate(X.dtypes):
        if dtype != 'float64':
            indices.append(position)
    return indices
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return names of the non-float (categorical) columns."""
    return [name for name, kind in df.dtypes.items() if kind != 'float64']
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """Build the SMOTENC ``sampling_strategy`` mapping from class counts.

    Class 0 is oversampled to 500 (when very rare) or 1000 samples,
    class 1 is rounded up to the next hundred, and class 2 (the
    majority) is left at its original count.

    Fix: SMOTE-family samplers raise ``ValueError`` if a requested
    target is below the class's existing count, so each target is
    clamped with ``max(..., current count)``.

    Args:
        y: Target series with labels {0, 1, 2}.

    Returns:
        Mapping of class label -> desired sample count after SMOTENC.
    """
    count_class_0 = int((y == 0).sum())
    count_class_1 = int((y == 1).sum())
    count_class_2 = int((y == 2).sum())

    target_class_0 = 500 if count_class_0 <= 500 else 1000
    return {
        # Never request fewer samples than already exist.
        0: max(target_class_0, count_class_0),
        1: int(np.ceil(count_class_1 / 100) * 100),  # round up to nearest 100
        2: count_class_2,
    }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample the data with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series.
        categorical_features_indices: Positions of categorical columns.
        sampling_strategy: Class -> target-count mapping.

    Returns:
        Resampled features with a ``multi_class`` column appended.
    """
    sampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    features_resampled, labels_resampled = sampler.fit_resample(X, y)

    combined = features_resampled.copy()
    combined['multi_class'] = labels_resampled
    return combined
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective that scores CTGAN hyperparameters.

    Each trial trains a CTGAN on the rows of ``data`` belonging to
    ``class_label``, samples twice as many synthetic rows, and scores
    the trial by how closely the synthetic ``visi`` column matches the
    real one in mean and standard deviation (negated so Optuna can
    maximize it).

    Args:
        data: Training data containing a ``multi_class`` column.
        class_label: Class to model (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Hyperparameter search ranges.

    Returns:
        An Optuna objective function (closure over the class slice).
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one hyperparameter configuration from the search space.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model for this trial.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train on the single-class slice only.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate synthetic samples (2x the real class count).
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluate: compare the continuous 'visi' distribution of the
        # real vs. synthetic data.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared differences of mean and std; negated because the study
        # maximizes (smaller mismatch => larger score).
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain with the best params, and sample.

    Args:
        data: Training data containing a ``multi_class`` column.
        class_label: Class to model (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Hyperparameter search ranges.
        n_trials: Number of Optuna trials to run.
        target_samples: Number of synthetic rows to generate.

    Returns:
        (generated sample DataFrame, fitted CTGAN model)
    """
    # Build the per-class objective.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search; the objective returns -mse, so maximize.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Rebuild CTGAN with the best hyperparameters found.
    best_params = study.best_params
    ctgan = CTGAN(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=best_params["generator_dim"],
        discriminator_dim=best_params["discriminator_dim"],
        batch_size=best_params["batch_size"],
        discriminator_steps=best_params["discriminator_steps"],
        pac=best_params["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit on the single-class slice, then sample.
    class_data = data[data['multi_class'] == class_label]
    ctgan.fit(class_data, discrete_columns=categorical_features)
    generated_samples = ctgan.sample(target_samples)

    return generated_samples, ctgan
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Recreate the derived columns that were dropped before oversampling.

    Adds the binary target (class 2 -> 0, others -> 1), cyclic hour and
    month encodings, and the ground-air temperature difference.

    Args:
        df: Input frame with 'multi_class', 'hour', 'month',
            'groundtemp' and 'temp_C' columns.

    Returns:
        A copy of ``df`` with the derived columns appended.
    """
    out = df.copy()
    out['binary_class'] = (out['multi_class'] != 2).astype(int)
    hour_angle = 2 * np.pi * out['hour'] / 24
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    month_angle = 2 * np.pi * out['month'] / 12
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC and CTGAN sequentially to one region's training data.

    Loads the region's training CSV, oversamples minority fog classes with
    SMOTENC, tunes and trains one CTGAN per fog class (0 and 1) to generate
    additional synthetic rows, saves the trained models, writes an
    "augmented only" CSV, and writes the final combined training CSV.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Extract the region name from the file path (e.g. "seoul_train" -> "seoul").
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC oversampling.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Number of CTGAN samples to generate for class 1.
    # NOTE(review): this can go negative when the real class-1 count exceeds
    # TARGET_SAMPLES_CLASS_1_BASE — confirm upstream counts stay below the base.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_20000_3_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_20000_3_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only generated samples whose visibility lies in each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only the augmented rows (SMOTENC additions + CTGAN samples);
    # the first len(X) rows of smotenc_data are the original data, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC-added rows only

    # Merge augmented data only (SMOTENC additions + CTGAN additions).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save into the augmented_only folder next to the output directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore derived feature columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the oversampled frame, then add back original class-2 rows.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the result.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ==================== Main execution ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output path per region; outputs land in the smotenc_ctgan20000 folder.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan20000/smotenc_ctgan20000_3_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_7000_1.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== Constants ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2018, 2019]
# Number of synthetic class-0 rows to draw from the tuned CTGAN.
TARGET_SAMPLES_CLASS_0 = 7000
# Class-1 budget before subtracting the (rounded-up) real class-1 count.
TARGET_SAMPLES_CLASS_1_BASE = 7000
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before oversampling (rebuilt later by add_derived_features)
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime: report the compute device, silence Optuna warnings.

    Returns:
        torch.device: "cuda" when a GPU is available, otherwise "cpu".
    """
    has_gpu = torch.cuda.is_available()
    device = torch.device("cuda") if has_gpu else torch.device("cpu")
    print(f"Using device: {device}")
    # Optuna emits UserWarnings for tuple-valued categorical choices; ignore them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """
    Load a region's CSV and split it into features and target.

    Args:
        file_path: Path to the data file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        (data, X, y): filtered raw frame, feature frame, target series.
    """
    raw = pd.read_csv(file_path, index_col=0)
    raw = raw[raw['year'].isin(train_years)]
    # Cloud-cover fields are treated as categorical codes; cast to int.
    for column in ('cloudcover', 'lm_cloudcover'):
        raw[column] = raw[column].astype('int')

    # Feature frame: everything except the labels and the derived columns
    # (the derived columns are rebuilt after oversampling).
    features = raw.drop(columns=['multi_class', 'binary_class'])
    features = features.drop(columns=COLUMNS_TO_DROP)
    target = raw['multi_class']

    return raw, features, target
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return the column positions of non-float (categorical) features."""
    positions = []
    for position, dtype in enumerate(X.dtypes):
        if dtype != 'float64':
            positions.append(position)
    return positions
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the column names of non-float (categorical) features."""
    return [name for name, dtype in df.dtypes.items() if dtype != 'float64']
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """
    Build the sampling_strategy mapping for SMOTENC.

    Class 0 is raised to 500 (1000 when it already exceeds 500), class 1 is
    rounded up to the next hundred, and class 2 keeps its original count.

    Args:
        y: Target series of class labels (0, 1, 2).

    Returns:
        Mapping of class label -> desired row count.
    """
    counts = {label: (y == label).sum() for label in (0, 1, 2)}
    return {
        0: 500 if counts[0] <= 500 else 1000,
        1: int(np.ceil(counts[1] / 100) * 100),  # round up to the next hundred
        2: counts[2],
    }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """
    Oversample minority classes with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series.
        categorical_features_indices: Column positions of categorical features.
        sampling_strategy: Desired row count per class label.

    Returns:
        Resampled feature frame with a 'multi_class' column appended.
    """
    sampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE
    )
    features_res, labels_res = sampler.fit_resample(X, y)

    combined = features_res.copy()
    combined['multi_class'] = labels_res
    return combined
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build an Optuna objective for CTGAN hyperparameter tuning.

    Each trial trains a CTGAN on the rows of `data` belonging to
    `class_label`, samples twice that many synthetic rows, and scores the
    trial by how closely the synthetic 'visi' mean/std match the real ones
    (negated squared error, so Optuna maximizes the match).

    Args:
        data: Training data (must contain a 'multi_class' column).
        class_label: Class label to model (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.

    Returns:
        Optuna objective function.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Sample hyperparameters from the configured ranges.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train the model.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate synthetic samples.
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluate: compare the continuous 'visi' variable's distribution.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared difference of mean and std (lower is better; negated for maximize).
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """
    Tune CTGAN hyperparameters with Optuna, then train and sample.

    Args:
        data: Training data.
        class_label: Class label (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.
        n_trials: Number of Optuna trials.
        target_samples: Number of samples to generate.

    Returns:
        (generated sample frame, trained CTGAN model)
    """
    # Build the objective function.
    objective = create_ctgan_objective(data, class_label, categorical_features, hp_ranges)

    # Run the Optuna search.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    # Retrain CTGAN with the best hyperparameters.
    best_params = study.best_params
    ctgan = CTGAN(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=best_params["generator_dim"],
        discriminator_dim=best_params["discriminator_dim"],
        batch_size=best_params["batch_size"],
        discriminator_steps=best_params["discriminator_steps"],
        pac=best_params["pac"]
    )
    ctgan.set_random_state(RANDOM_STATE)

    # Final fit and sampling on the selected class's rows.
    class_data = data[data['multi_class'] == class_label]
    ctgan.fit(class_data, discrete_columns=categorical_features)
    generated_samples = ctgan.sample(target_samples)

    return generated_samples, ctgan
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived feature columns that were dropped before oversampling.

    Adds a binary fog label, cyclic encodings of hour/month, and the
    ground/air temperature difference. The input frame is not modified.

    Args:
        df: Frame containing 'multi_class', 'hour', 'month', 'groundtemp'
            and 'temp_C' columns.

    Returns:
        Copy of the frame with the derived columns appended.
    """
    out = df.copy()
    # binary_class: 1 for fog classes (multi_class 0/1), 0 for clear (class 2).
    out['binary_class'] = out['multi_class'].apply(lambda label: 0 if label == 2 else 1)
    # Cyclic encodings so hour 23 is close to hour 0 and December to January.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    # Ground-minus-air temperature difference (column name kept as-is).
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC and CTGAN sequentially to one region's training data.

    Loads the region's training CSV, oversamples minority fog classes with
    SMOTENC, tunes and trains one CTGAN per fog class (0 and 1) to generate
    additional synthetic rows, saves the trained models, writes an
    "augmented only" CSV, and writes the final combined training CSV.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where trained CTGAN models are saved.
    """
    # Extract the region name from the file path (e.g. "seoul_train" -> "seoul").
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC oversampling.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Number of CTGAN samples to generate for class 1.
    # NOTE(review): this can go negative when the real class-1 count exceeds
    # TARGET_SAMPLES_CLASS_1_BASE — confirm upstream counts stay below the base.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Create the model save directory.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class-0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_7000_1_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class-1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_7000_1_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only generated samples whose visibility lies in each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract only the augmented rows (SMOTENC additions + CTGAN samples);
    # the first len(X) rows of smotenc_data are the original data, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC-added rows only

    # Merge augmented data only (SMOTENC additions + CTGAN additions).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save into the augmented_only folder next to the output directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore derived feature columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 from the oversampled frame, then add back original class-2 rows.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Create the output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the result.
    final_data.to_csv(output_path, index=False)

    # Report final class counts.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ==================== Main execution ====================

if __name__ == "__main__":
    setup_environment()

    # One input/output path per region; outputs land in the smotenc_ctgan7000 folder.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan7000/smotenc_ctgan7000_1_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_7000_2.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== Constants ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
# NOTE(review): this variant trains on 2018/2020 while the "_1" variant uses
# 2018/2019 — presumably a deliberate replicate split; confirm it is intended.
TRAIN_YEARS = [2018, 2020]
# Number of synthetic class-0 rows to draw from the tuned CTGAN.
TARGET_SAMPLES_CLASS_0 = 7000
# Class-1 budget before subtracting the (rounded-up) real class-1 count.
TARGET_SAMPLES_CLASS_1_BASE = 7000
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns dropped before oversampling (rebuilt later by add_derived_features)
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime: report the compute device, silence Optuna warnings.

    Returns:
        torch.device: "cuda" when a GPU is available, otherwise "cpu".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Optuna emits UserWarnings for tuple-valued categorical choices; ignore them.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """
    Load a region's CSV and split it into features and target.

    Args:
        file_path: Path to the data file (first column is the index).
        train_years: Years to keep for training.

    Returns:
        (data, X, y): filtered raw frame, feature frame, target series.
    """
    data = pd.read_csv(file_path, index_col=0)
    data = data.loc[data['year'].isin(train_years), :]
    # Cloud-cover fields are treated as categorical codes; cast to int.
    data['cloudcover'] = data['cloudcover'].astype('int')
    data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')

    X = data.drop(columns=['multi_class', 'binary_class'])
    y = data['multi_class']

    # Drop derived columns; they are rebuilt after oversampling.
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return the column positions of non-float (categorical) features."""
    return [i for i, dtype in enumerate(X.dtypes) if dtype != 'float64']
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the column names of non-float (categorical) features."""
    return [col for col, dtype in zip(df.columns, df.dtypes) if dtype != 'float64']
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """
    Build the sampling_strategy mapping for SMOTENC.

    Class 0 is raised to 500 (1000 when it already exceeds 500), class 1 is
    rounded up to the next hundred, and class 2 keeps its original count.

    Args:
        y: Target series of class labels (0, 1, 2).

    Returns:
        Mapping of class label -> desired row count.
    """
    count_class_0 = (y == 0).sum()
    count_class_1 = (y == 1).sum()
    count_class_2 = (y == 2).sum()

    return {
        0: 500 if count_class_0 <= 500 else 1000,
        1: int(np.ceil(count_class_1 / 100) * 100),  # round up to the next hundred
        2: count_class_2
    }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """
    Oversample minority classes with SMOTENC.

    Args:
        X: Feature frame.
        y: Target series.
        categorical_features_indices: Column positions of categorical features.
        sampling_strategy: Desired row count per class label.

    Returns:
        Resampled feature frame with a 'multi_class' column appended.
    """
    smotenc = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE
    )
    X_resampled, y_resampled = smotenc.fit_resample(X, y)

    resampled_data = X_resampled.copy()
    resampled_data['multi_class'] = y_resampled

    return resampled_data
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """
    Build an Optuna objective for CTGAN hyperparameter tuning.

    Each trial trains a CTGAN on the rows of `data` belonging to
    `class_label`, samples twice that many synthetic rows, and scores the
    trial by how closely the synthetic 'visi' mean/std match the real ones
    (negated squared error, so Optuna maximizes the match).

    Args:
        data: Training data (must contain a 'multi_class' column).
        class_label: Class label to model (0 or 1).
        categorical_features: Names of categorical columns.
        hp_ranges: Hyperparameter search ranges.

    Returns:
        Optuna objective function.
    """
    class_data = data[data['multi_class'] == class_label]

    def objective(trial):
        # Sample hyperparameters from the configured ranges.
        embedding_dim = trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim'])
        generator_dim = trial.suggest_categorical("generator_dim", hp_ranges['generator_dim'])
        discriminator_dim = trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim'])
        pac = trial.suggest_categorical("pac", hp_ranges['pac'])
        batch_size = trial.suggest_categorical("batch_size", hp_ranges['batch_size'])
        discriminator_steps = trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps'])

        # Build the CTGAN model.
        ctgan = CTGAN(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            pac=pac
        )
        ctgan.set_random_state(RANDOM_STATE)

        # Train the model.
        ctgan.fit(class_data, discrete_columns=categorical_features)

        # Generate synthetic samples.
        generated_data = ctgan.sample(len(class_data) * 2)

        # Evaluate: compare the continuous 'visi' variable's distribution.
        real_visi = class_data['visi']
        generated_visi = generated_data['visi']

        # Squared difference of mean and std (lower is better; negated for maximize).
        mse = ((real_visi.mean() - generated_visi.mean())**2 +
               (real_visi.std() - generated_visi.std())**2)
        return -mse

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain the best model, and sample from it.

    Args:
        data: Training data containing a 'multi_class' column.
        class_label: Class to model (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Hyperparameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        Tuple of (generated samples DataFrame, fitted CTGAN model).
    """
    # Run the hyperparameter search.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        create_ctgan_objective(data, class_label, categorical_features, hp_ranges),
        n_trials=n_trials,
    )
    best = study.best_params

    # Retrain a fresh CTGAN with the winning configuration.
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"],
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then draw the requested sample count.
    subset = data[data['multi_class'] == class_label]
    model.fit(subset, discrete_columns=categorical_features)
    return model.sample(target_samples), model
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived columns that were dropped before oversampling.

    Adds the binary target, cyclic hour/month encodings, and the
    ground-minus-air temperature difference.

    Args:
        df: Frame with 'multi_class', 'hour', 'month', 'groundtemp' and
            'temp_C' columns.

    Returns:
        A copy of ``df`` with the derived columns appended.
    """
    out = df.copy()
    # Class 2 maps to 0; every other class maps to 1.
    out['binary_class'] = out['multi_class'].apply(lambda c: 0 if c == 2 else 1)
    # Cyclic encodings so hour 23 wraps to hour 0 and month 12 to month 1.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC followed by CTGAN augmentation to one region's data.

    Loads the region's training data, oversamples minority classes with
    SMOTENC, then trains Optuna-tuned CTGAN models for classes 0 and 1 to
    generate additional samples. Saves the fitted CTGAN models, the
    augmented-only rows, and the merged final dataset to disk.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where CTGAN models are saved.
    """
    # Region name derived from the file name (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Per-class budget: class 1's CTGAN count fills the gap left after SMOTENC.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Ensure the model output directory exists.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class 0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_7000_2_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class 1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_7000_2_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only samples whose 'visi' value falls in each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract augmented rows only (SMOTENC additions + CTGAN samples).
    # The first len(X) rows of smotenc_data are the original data, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC additions only

    # Merge the augmented rows only (SMOTENC + CTGAN).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save into the sibling 'augmented_only' directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 rows, then append the original class 2 data unchanged.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report the final class distribution.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    setup_environment()

    # One input/output pair per region; variant "2" of the 7000-sample run.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan7000/smotenc_ctgan7000_2_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/2.make_oversample_data/smotenc_ctgan/smotenc_ctgan_sample_7000_3.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from imblearn.over_sampling import SMOTENC
|
| 6 |
+
import optuna
|
| 7 |
+
from ctgan import CTGAN
|
| 8 |
+
import torch
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
# ==================== Constants ====================
REGIONS = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
TRAIN_YEARS = [2019, 2020]
TARGET_SAMPLES_CLASS_0 = 7000          # CTGAN samples to draw for class 0
TARGET_SAMPLES_CLASS_1_BASE = 7000     # total class-1 budget; CTGAN fills the gap after SMOTENC
RANDOM_STATE = 42

# Optuna optimization settings
CLASS_0_TRIALS = 50
CLASS_1_TRIALS = 30

# Per-class hyperparameter search ranges
CLASS_0_HP_RANGES = {
    'embedding_dim': (64, 128),
    'generator_dim': [(64, 64), (128, 128)],
    'discriminator_dim': [(64, 64), (128, 128)],
    'pac': [4, 8],
    'batch_size': [64, 128, 256],
    'discriminator_steps': (1, 3)
}

CLASS_1_HP_RANGES = {
    'embedding_dim': (128, 512),
    'generator_dim': [(128, 128), (256, 256)],
    'discriminator_dim': [(128, 128), (256, 256)],
    'pac': [4, 8],
    'batch_size': [256, 512, 1024],
    'discriminator_steps': (1, 5)
}

# Derived columns removed before oversampling (recreated afterwards)
COLUMNS_TO_DROP = ['ground_temp - temp_C', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos']
|
| 43 |
+
|
| 44 |
+
# ==================== ์ ํธ๋ฆฌํฐ ํจ์ ====================
|
| 45 |
+
|
| 46 |
+
def setup_environment():
    """Configure the runtime: report the compute device and silence noisy warnings.

    Returns:
        The torch device that will be used (CUDA when available, else CPU).
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")
    # Hide UserWarnings emitted by Optuna's distribution module.
    warnings.filterwarnings("ignore", category=UserWarning, module="optuna.distributions")
    return device
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_and_preprocess_data(file_path: str, train_years: list) -> tuple:
    """
    Load a region's CSV and prepare features/target for oversampling.

    Args:
        file_path: Path to the input CSV (first column is the index).
        train_years: Years to keep for training.

    Returns:
        (data, X, y): the filtered source frame, the feature frame with the
        derived columns removed, and the 'multi_class' target series.
    """
    data = pd.read_csv(file_path, index_col=0)
    # .copy() so the following column assignments modify an independent frame
    # rather than a view of the original (avoids SettingWithCopyWarning and
    # any aliasing back into the unfiltered frame).
    data = data.loc[data['year'].isin(train_years), :].copy()
    data['cloudcover'] = data['cloudcover'].astype('int')
    data['lm_cloudcover'] = data['lm_cloudcover'].astype('int')

    X = data.drop(columns=['multi_class', 'binary_class'])
    y = data['multi_class']

    # Derived columns are re-created after augmentation; drop them here.
    X.drop(columns=COLUMNS_TO_DROP, inplace=True)

    return data, X, y
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_categorical_feature_indices(X: pd.DataFrame) -> list:
    """Return positional indices of the non-float (categorical) columns."""
    indices = []
    for position, dtype in enumerate(X.dtypes):
        if dtype != 'float64':
            indices.append(position)
    return indices
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_categorical_feature_names(df: pd.DataFrame) -> list:
    """Return the names of the non-float (categorical) columns."""
    return [name for name, dtype in df.dtypes.items() if dtype != 'float64']
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def calculate_sampling_strategy(y: pd.Series) -> dict:
    """
    Build the per-class target counts for SMOTENC.

    Class 0 is raised to 500 (or 1000 when it already exceeds 500), class 1
    is rounded up to the next hundred, and class 2 keeps its current count.
    NOTE(review): SMOTENC requires each target to be >= the current class
    count — presumably the raw counts never exceed these targets; verify.

    Args:
        y: Target series with labels 0/1/2.

    Returns:
        Mapping of class label to desired sample count.
    """
    n0 = (y == 0).sum()
    n1 = (y == 1).sum()
    n2 = (y == 2).sum()

    class_0_target = 500 if n0 <= 500 else 1000
    class_1_target = int(np.ceil(n1 / 100) * 100)  # round up to the next hundred

    return {0: class_0_target, 1: class_1_target, 2: n2}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def apply_smotenc(X: pd.DataFrame, y: pd.Series,
                  categorical_features_indices: list,
                  sampling_strategy: dict) -> pd.DataFrame:
    """Oversample the minority classes with SMOTENC.

    Args:
        X: Feature matrix.
        y: Target labels.
        categorical_features_indices: Positional indices of categorical columns.
        sampling_strategy: Per-class target sample counts.

    Returns:
        Resampled features with the target attached as a 'multi_class' column.
    """
    sampler = SMOTENC(
        categorical_features=categorical_features_indices,
        sampling_strategy=sampling_strategy,
        random_state=RANDOM_STATE,
    )
    X_res, y_res = sampler.fit_resample(X, y)

    combined = X_res.copy()
    combined['multi_class'] = y_res
    return combined
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def create_ctgan_objective(data: pd.DataFrame, class_label: int,
                           categorical_features: list,
                           hp_ranges: dict) -> callable:
    """Build an Optuna objective that scores one CTGAN hyperparameter setting.

    The objective trains a CTGAN on the rows belonging to ``class_label`` and
    scores the fit by how closely the synthetic 'visi' distribution matches
    the real one (negated squared error of mean and std, so Optuna maximizes).

    Args:
        data: Training data containing a 'multi_class' column.
        class_label: Class whose rows are modelled (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Search space for each hyperparameter.

    Returns:
        A function suitable for ``study.optimize``.
    """
    subset = data[data['multi_class'] == class_label]

    def objective(trial):
        # Draw one configuration from the search space.
        params = dict(
            embedding_dim=trial.suggest_int("embedding_dim", *hp_ranges['embedding_dim']),
            generator_dim=trial.suggest_categorical("generator_dim", hp_ranges['generator_dim']),
            discriminator_dim=trial.suggest_categorical("discriminator_dim", hp_ranges['discriminator_dim']),
            pac=trial.suggest_categorical("pac", hp_ranges['pac']),
            batch_size=trial.suggest_categorical("batch_size", hp_ranges['batch_size']),
            discriminator_steps=trial.suggest_int("discriminator_steps", *hp_ranges['discriminator_steps']),
        )

        model = CTGAN(**params)
        model.set_random_state(RANDOM_STATE)
        model.fit(subset, discrete_columns=categorical_features)

        # Sample twice the real row count for a more stable estimate.
        synthetic = model.sample(len(subset) * 2)

        # Squared distance between real and synthetic 'visi' mean/std.
        real_visi = subset['visi']
        fake_visi = synthetic['visi']
        gap = ((real_visi.mean() - fake_visi.mean()) ** 2 +
               (real_visi.std() - fake_visi.std()) ** 2)
        return -gap

    return objective
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def optimize_and_generate_samples(data: pd.DataFrame, class_label: int,
                                  categorical_features: list,
                                  hp_ranges: dict, n_trials: int,
                                  target_samples: int) -> tuple:
    """Tune CTGAN with Optuna, retrain the best model, and sample from it.

    Args:
        data: Training data containing a 'multi_class' column.
        class_label: Class to model (0 or 1).
        categorical_features: Names of discrete columns for CTGAN.
        hp_ranges: Hyperparameter search space.
        n_trials: Number of Optuna trials.
        target_samples: Number of synthetic rows to generate.

    Returns:
        Tuple of (generated samples DataFrame, fitted CTGAN model).
    """
    # Run the hyperparameter search.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        create_ctgan_objective(data, class_label, categorical_features, hp_ranges),
        n_trials=n_trials,
    )
    best = study.best_params

    # Retrain a fresh CTGAN with the winning configuration.
    model = CTGAN(
        embedding_dim=best["embedding_dim"],
        generator_dim=best["generator_dim"],
        discriminator_dim=best["discriminator_dim"],
        batch_size=best["batch_size"],
        discriminator_steps=best["discriminator_steps"],
        pac=best["pac"],
    )
    model.set_random_state(RANDOM_STATE)

    # Final fit on this class's rows, then draw the requested sample count.
    subset = data[data['multi_class'] == class_label]
    model.fit(subset, discrete_columns=categorical_features)
    return model.sample(target_samples), model
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived columns that were dropped before oversampling.

    Adds the binary target, cyclic hour/month encodings, and the
    ground-minus-air temperature difference.

    Args:
        df: Frame with 'multi_class', 'hour', 'month', 'groundtemp' and
            'temp_C' columns.

    Returns:
        A copy of ``df`` with the derived columns appended.
    """
    out = df.copy()
    # Class 2 maps to 0; every other class maps to 1.
    out['binary_class'] = out['multi_class'].apply(lambda c: 0 if c == 2 else 1)
    # Cyclic encodings so hour 23 wraps to hour 0 and month 12 to month 1.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def process_region(file_path: str, output_path: str, model_save_dir: Path) -> None:
    """
    Apply SMOTENC followed by CTGAN augmentation to one region's data.

    Loads the region's training data, oversamples minority classes with
    SMOTENC, then trains Optuna-tuned CTGAN models for classes 0 and 1 to
    generate additional samples. Saves the fitted CTGAN models, the
    augmented-only rows, and the merged final dataset to disk.

    Args:
        file_path: Input data file path.
        output_path: Output data file path.
        model_save_dir: Directory where CTGAN models are saved.
    """
    # Region name derived from the file name (e.g. 'seoul_train' -> 'seoul').
    region_name = Path(file_path).stem.replace('_train', '')

    # Load and preprocess the data.
    original_data, X, y = load_and_preprocess_data(file_path, TRAIN_YEARS)

    # Apply SMOTENC.
    categorical_features_indices = get_categorical_feature_indices(X)
    sampling_strategy = calculate_sampling_strategy(y)
    smotenc_data = apply_smotenc(X, y, categorical_features_indices, sampling_strategy)

    # Categorical column names for CTGAN's discrete_columns argument.
    categorical_features = get_categorical_feature_names(smotenc_data)

    # Per-class budget: class 1's CTGAN count fills the gap left after SMOTENC.
    count_class_1 = (y == 1).sum()
    target_samples_class_1 = TARGET_SAMPLES_CLASS_1_BASE - int(np.ceil(count_class_1 / 100) * 100)

    # Optimize CTGAN and generate samples for class 0.
    print(f"Processing {file_path}: Optimizing CTGAN for class 0...")
    generated_0, ctgan_model_0 = optimize_and_generate_samples(
        smotenc_data, 0, categorical_features,
        CLASS_0_HP_RANGES, CLASS_0_TRIALS, TARGET_SAMPLES_CLASS_0
    )

    # Optimize CTGAN and generate samples for class 1.
    print(f"Processing {file_path}: Optimizing CTGAN for class 1...")
    generated_1, ctgan_model_1 = optimize_and_generate_samples(
        smotenc_data, 1, categorical_features,
        CLASS_1_HP_RANGES, CLASS_1_TRIALS, target_samples_class_1
    )

    # Ensure the model output directory exists.
    model_save_dir.mkdir(parents=True, exist_ok=True)

    # Save the class 0 model.
    model_path_0 = model_save_dir / f'smotenc_ctgan_7000_3_{region_name}_class0.pkl'
    ctgan_model_0.save(str(model_path_0))
    print(f"Saved CTGAN model for class 0: {model_path_0}")

    # Save the class 1 model.
    model_path_1 = model_save_dir / f'smotenc_ctgan_7000_3_{region_name}_class1.pkl'
    ctgan_model_1.save(str(model_path_1))
    print(f"Saved CTGAN model for class 1: {model_path_1}")

    # Keep only samples whose 'visi' value falls in each class's range.
    well_generated_0 = generated_0[
        (generated_0['visi'] >= 0) & (generated_0['visi'] < 100)
    ]
    well_generated_1 = generated_1[
        (generated_1['visi'] >= 100) & (generated_1['visi'] < 500)
    ]

    # Extract augmented rows only (SMOTENC additions + CTGAN samples).
    # The first len(X) rows of smotenc_data are the original data, so skip them.
    original_data_count = len(X)
    smotenc_augmented = smotenc_data.iloc[original_data_count:].copy()  # SMOTENC additions only

    # Merge the augmented rows only (SMOTENC + CTGAN).
    augmented_only = pd.concat([smotenc_augmented, well_generated_0, well_generated_1], axis=0)
    augmented_only = add_derived_features(augmented_only)
    augmented_only.reset_index(drop=True, inplace=True)
    # Save into the sibling 'augmented_only' directory.
    output_path_obj = Path(output_path)
    augmented_dir = output_path_obj.parent.parent / 'augmented_only'
    augmented_dir.mkdir(parents=True, exist_ok=True)
    augmented_output_path = augmented_dir / output_path_obj.name
    augmented_only.to_csv(augmented_output_path, index=False)

    # Merge SMOTENC data with the filtered CTGAN samples (final result).
    smote_gan_data = pd.concat([smotenc_data, well_generated_0, well_generated_1], axis=0)

    # Restore the derived columns.
    smote_gan_data = add_derived_features(smote_gan_data)

    # Report the augmented-only counts.
    aug_count_0 = len(augmented_only[augmented_only['multi_class'] == 0])
    aug_count_1 = len(augmented_only[augmented_only['multi_class'] == 1])
    print(f"Saved augmented data only {augmented_output_path}: Class 0={aug_count_0} | Class 1={aug_count_1}")

    # Drop class 2 rows, then append the original class 2 data unchanged.
    filtered_data = smote_gan_data[smote_gan_data['multi_class'] != 2]
    original_class_2 = original_data[original_data['multi_class'] == 2]
    final_data = pd.concat([filtered_data, original_class_2], axis=0)
    final_data.reset_index(drop=True, inplace=True)

    # Ensure the output directory exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Save the final dataset.
    final_data.to_csv(output_path, index=False)

    # Report the final class distribution.
    count_0 = len(final_data[final_data['multi_class'] == 0])
    count_1 = len(final_data[final_data['multi_class'] == 1])
    count_2 = len(final_data[final_data['multi_class'] == 2])
    print(f"Saved {output_path}: Class 0={count_0} | Class 1={count_1} | Class 2={count_2}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ==================== ๋ฉ์ธ ์คํ ====================
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    setup_environment()

    # One input/output pair per region; variant "3" of the 7000-sample run.
    file_paths = [f'../../../data/data_for_modeling/{region}_train.csv' for region in REGIONS]
    output_paths = [f'../../../data/data_oversampled/smotenc_ctgan7000/smotenc_ctgan7000_3_{region}.csv' for region in REGIONS]
    model_save_dir = Path('../../save_model/oversampling_models')

    for file_path, output_path in zip(file_paths, output_paths):
        process_region(file_path, output_path, model_save_dir)
|
Analysis_code/3.sampled_data_analysis/make_plot.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
๋ฐ์ดํฐ ์๊ฐํ ๋ชจ๋: Original๊ณผ Synthetic ๋ฐ์ดํฐ ๋น๊ต ์๊ฐํ
|
| 3 |
+
|
| 4 |
+
์ด ๋ชจ๋์ ์๋ณธ ๋ฐ์ดํฐ์ ํฉ์ฑ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๊ณ , ์ ์ฒ๋ฆฌํ ํ
|
| 5 |
+
UMAP์ ์ฌ์ฉํ์ฌ ์ฐจ์ ์ถ์ ๋ฐ ์๊ฐํ๋ฅผ ์ํํฉ๋๋ค.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
# TensorFlow ๋ก๊ทธ ๋ฉ์์ง ์จ๊ธฐ๊ธฐ
|
| 10 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 0=๋ชจ๋, 1=INFO ์ ์ธ, 2=INFO/WARNING ์ ์ธ, 3=ERROR๋ง
|
| 11 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # oneDNN ๊ฒฝ๊ณ ์จ๊ธฐ๊ธฐ
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
import warnings
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import List, Tuple, Optional
|
| 20 |
+
from sklearn.preprocessing import StandardScaler
|
| 21 |
+
import umap
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class PlotConfig:
    """Configuration values for the Original-vs-Synthetic visualization.

    Attributes:
        cols_to_drop: Columns removed before scaling/embedding. ``None``
            selects the defaults in ``__post_init__``. (The original
            annotation ``List[str] = None`` contradicted itself; it is
            ``Optional`` here, with behavior unchanged.)
        umap_n_neighbors: UMAP ``n_neighbors``.
        umap_min_dist: UMAP ``min_dist``.
        umap_random_state: UMAP seed for reproducibility.
        umap_n_jobs: Kept at 1 — UMAP cannot parallelize when a
            ``random_state`` is set (avoids the library warning).
        figsize: Default figure size.
        alpha: Point transparency, shared by Original and Synthetic points.
        visibility_threshold: Visibility cutoff for the binary class.
        scale_on_original_only: True fits the scaler on the original data
            only (prevents synthetic-data leakage into the scaling);
            False fits on the concatenation.
    """
    cols_to_drop: Optional[List[str]] = None
    umap_n_neighbors: int = 30
    umap_min_dist: float = 0.1
    umap_random_state: int = 42
    umap_n_jobs: int = 1  # parallelism is unavailable when random_state is set
    figsize: Tuple[int, int] = (16, 6)
    alpha: float = 0.6  # same transparency for Original and Synthetic
    visibility_threshold: int = 500
    scale_on_original_only: bool = True  # True: fit scaler on original only (no leakage)

    def __post_init__(self):
        """Fill in the default drop-column list when none was provided."""
        if self.cols_to_drop is None:
            self.cols_to_drop = [
                'wind_dir',            # string-valued (breaks numeric scaling)
                'multi_class',         # target variable (used only for coloring)
                'binary_class',        # target variable
                'year', 'month', 'hour',        # redundant with the sin/cos features
                'ground_temp - temp_C',         # simple linear combination (redundant)
                'visi'
            ]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def add_time_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with cyclic sin/cos encodings added.

    Encodes ``hour`` on a 24-step cycle and ``month`` on a 12-step cycle,
    producing ``hour_sin``/``hour_cos``/``month_sin``/``month_cos``.

    Args:
        df: input frame; must contain ``hour`` and ``month`` columns.

    Returns:
        A new DataFrame (the input is not modified) with the four
        cyclic feature columns appended.
    """
    enriched = df.copy()
    for column, period in (('hour', 24), ('month', 12)):
        angle = 2 * np.pi * enriched[column] / period
        enriched[f'{column}_sin'] = np.sin(angle)
        enriched[f'{column}_cos'] = np.cos(angle)
    return enriched
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def create_binary_class(visi: pd.Series, threshold: int = 500) -> pd.Series:
    """Binarize a visibility series against *threshold*.

    Args:
        visi: visibility values.
        threshold: cutoff (default 500).

    Returns:
        Series of 1 where ``visi < threshold``, 0 where
        ``visi >= threshold``, and NaN where the value is NaN
        (NaN fails both comparisons).
    """
    def _classify(value):
        if value < threshold:
            return 1
        if value >= threshold:
            return 0
        return np.nan  # only reachable for NaN input

    return visi.apply(_classify)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_region_data(
    region: str,
    data_dir: str = "../../data/data_for_modeling"
) -> pd.DataFrame:
    """Load the original training data for one region.

    Args:
        region: region name ('incheon', 'seoul', 'busan', 'daegu',
            'daejeon', 'gwangju').
        data_dir: directory containing the ``<region>_train.csv`` files.

    Returns:
        DataFrame restricted to the modeling columns that actually exist
        in the file, in the canonical order below.
    """
    wanted = (
        'temp_C', 'precip_mm', 'wind_speed', 'wind_dir', 'hm', 'vap_pressure',
        'dewpoint_C', 'loc_pressure', 'sea_pressure', 'solarRad', 'snow_cm',
        'cloudcover', 'lm_cloudcover', 'low_cloudbase', 'groundtemp', 'O3',
        'NO2', 'PM10', 'PM25', 'year', 'month', 'hour', 'visi', 'multi_class',
        'binary_class', 'hour_sin', 'hour_cos', 'month_sin', 'month_cos',
        'ground_temp - temp_C'
    )
    frame = pd.read_csv(f"{data_dir}/{region}_train.csv")
    # Keep only the expected columns, tolerating files that lack some of them.
    present = [name for name in wanted if name in frame.columns]
    return frame.loc[:, present].copy()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def load_and_preprocess_data(
    synthetic_path: str,
    config: PlotConfig,
    region: Optional[str] = None,
    fold_idx: Optional[int] = None,
    data_dir: str = "../../data/data_for_modeling",
    original_path: Optional[str] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load and preprocess the original and synthetic datasets.

    The original data comes either from an explicit file
    (``original_path``) or from a region's train file optionally
    restricted to a cross-validation fold's year pair.

    Args:
        synthetic_path: path to the synthetic-data CSV.
        config: PlotConfig (supplies visibility threshold and drop columns).
        region: region name ('incheon', 'seoul', 'busan', 'daegu',
            'daejeon', 'gwangju'); used when original_path is None.
        fold_idx: fold index (0, 1 or 2); None keeps all years. Used when
            original_path is None.
        data_dir: directory of the original data (region mode).
        original_path: explicit original-data CSV; when given, region and
            fold_idx are ignored for loading.

    Returns:
        Tuple of (preprocessed original frame, preprocessed synthetic frame),
        each with a 'Label' column ('Original' / 'Synthetic') and with
        ``config.cols_to_drop`` removed.

    Raises:
        ValueError: if neither original_path nor region is provided.
    """
    # Load the original data: explicit path wins over region/fold lookup.
    if original_path is not None:
        # Legacy mode: load directly from a file path.
        original_data = pd.read_csv(original_path)
    elif region is not None:
        # Region mode: load the region's train file, then filter by fold.
        original_data = load_region_data(region, data_dir)

        # Restrict to the fold's pair of years (out-of-range fold_idx is
        # silently ignored and keeps the full data).
        if fold_idx is not None:
            fold = [[2018, 2019], [2018, 2020], [2019, 2020]]
            if 0 <= fold_idx < len(fold):
                years = fold[fold_idx]
                original_data = original_data.loc[original_data['year'].isin(years), :].copy()
    else:
        raise ValueError("original_path ๋๋ region์ ์ง์ ํด์ผ ํฉ๋๋ค.")

    # Load the synthetic data.
    synthetic_data = pd.read_csv(synthetic_path)

    # Derive the binary class from visibility on both frames.
    original_data['binary_class'] = create_binary_class(
        original_data['visi'],
        config.visibility_threshold
    )
    synthetic_data['binary_class'] = create_binary_class(
        synthetic_data['visi'],
        config.visibility_threshold
    )

    # Add cyclic hour/month features.
    original_data = add_time_features(original_data)
    synthetic_data = add_time_features(synthetic_data)

    # Keep only multi_class 0/1 rows (original data only).
    original_data = original_data.loc[original_data['multi_class'].isin([0, 1]), :]

    # Tag each frame for later plotting.
    # NOTE(review): assigning into the .loc-filtered frame above can emit a
    # pandas SettingWithCopyWarning; behavior is unchanged here.
    original_data['Label'] = 'Original'
    synthetic_data['Label'] = 'Synthetic'

    # Drop columns that must not enter scaling/embedding.
    original_data = original_data.drop(config.cols_to_drop, axis=1)
    synthetic_data = synthetic_data.drop(config.cols_to_drop, axis=1)

    return original_data, synthetic_data
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def prepare_features_for_visualization(
    original_data: pd.DataFrame,
    synthetic_data: pd.DataFrame,
    config: PlotConfig
) -> Tuple[np.ndarray, pd.Series, StandardScaler]:
    """Assemble and scale the feature matrix for the UMAP plot.

    By default (``config.scale_on_original_only``) the scaler is fit on
    the original data only and the synthetic data is merely transformed,
    so the synthetic distribution cannot leak into the scaling. The
    alternative fits on the concatenation of both frames.

    Args:
        original_data: original frame, including a 'Label' column.
        synthetic_data: synthetic frame, including a 'Label' column.
        config: PlotConfig (reads ``scale_on_original_only``).

    Returns:
        Tuple of (scaled feature matrix with original rows first,
        concatenated labels, fitted scaler).
    """
    feats_original = original_data.drop('Label', axis=1)
    feats_synthetic = synthetic_data.drop('Label', axis=1)

    scaler = StandardScaler()
    if config.scale_on_original_only:
        # Leakage-free path (recommended): fit on original, transform synthetic.
        matrix = np.vstack([
            scaler.fit_transform(feats_original),
            scaler.transform(feats_synthetic),
        ])
        labels = pd.concat(
            [original_data['Label'], synthetic_data['Label']],
            ignore_index=True,
        )
    else:
        # Joint scaling: the synthetic distribution influences the scaler.
        # Use only for comparison purposes.
        combined = pd.concat([original_data, synthetic_data], ignore_index=True)
        labels = combined['Label']
        matrix = scaler.fit_transform(combined.drop('Label', axis=1))

    return matrix, labels, scaler
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def plot_umap_comparison(
|
| 242 |
+
scaled_features: np.ndarray,
|
| 243 |
+
labels: pd.Series,
|
| 244 |
+
config: PlotConfig,
|
| 245 |
+
region: Optional[str] = None,
|
| 246 |
+
fold_idx: Optional[int] = None,
|
| 247 |
+
ax: Optional[plt.Axes] = None
|
| 248 |
+
) -> plt.Figure:
|
| 249 |
+
"""
|
| 250 |
+
UMAP์ ์ฌ์ฉํ์ฌ ์ฐจ์ ์ถ์ ํ Original๊ณผ Synthetic ๋ฐ์ดํฐ๋ฅผ ๋น๊ต ์๊ฐํํฉ๋๋ค.
|
| 251 |
+
|
| 252 |
+
ํต์ฌ: ์๋ณธ ๋ฐ์ดํฐ๊ฐ ์ ์ํ ๊ณต๊ฐ(Manifold) ์์ ํฉ์ฑ ๋ฐ์ดํฐ๋ฅผ ํฌ์ํฉ๋๋ค.
|
| 253 |
+
- Original ๋ฐ์ดํฐ๋ก๋ง UMAP์ fitํ์ฌ ๊ณต๊ฐ ๊ตฌ์กฐ๋ฅผ ํ์ต
|
| 254 |
+
- Synthetic ๋ฐ์ดํฐ๋ ํ์ต๋ ๊ณต๊ฐ์ transform๋ง ์ ์ฉ
|
| 255 |
+
- ์ด๋ ๊ฒ ํ๋ฉด ํฉ์ฑ ๋ฐ์ดํฐ๊ฐ ์๋ณธ ๋ฐ์ดํฐ์ ๊ณต๊ฐ ํ์ฑ์ ์ํฅ์ ์ฃผ์ง ์์ต๋๋ค.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
scaled_features: ์ค์ผ์ผ๋ง๋ ํผ์ฒ ๋ฐฐ์ด
|
| 259 |
+
labels: ๋ฐ์ดํฐ ๋ผ๋ฒจ (Original/Synthetic)
|
| 260 |
+
config: PlotConfig ๊ฐ์ฒด
|
| 261 |
+
region: ์ง์ญ๋ช
(ํ์์ฉ)
|
| 262 |
+
fold_idx: fold ์ธ๋ฑ์ค (ํ์์ฉ)
|
| 263 |
+
ax: matplotlib axes ๊ฐ์ฒด (None์ด๏ฟฝ๏ฟฝ๏ฟฝ ์ figure ์์ฑ)
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
matplotlib Figure ๊ฐ์ฒด
|
| 267 |
+
"""
|
| 268 |
+
print("UMAP ์คํ ์ค... (Original ๊ธฐ์ค ํ์ต ํ Synthetic ๋ณํ)")
|
| 269 |
+
|
| 270 |
+
# 1. ๋ฐ์ดํฐ ๋ถ๋ฆฌ (Labels๋ฅผ ์ด์ฉํด์ ๋ค์ ๋๋)
|
| 271 |
+
is_original = labels == 'Original'
|
| 272 |
+
original_data = scaled_features[is_original]
|
| 273 |
+
synthetic_data = scaled_features[~is_original]
|
| 274 |
+
|
| 275 |
+
# 2. UMAP ๋ชจ๋ธ ์์ฑ
|
| 276 |
+
umap_model = umap.UMAP(
|
| 277 |
+
n_neighbors=config.umap_n_neighbors,
|
| 278 |
+
min_dist=config.umap_min_dist,
|
| 279 |
+
random_state=config.umap_random_state,
|
| 280 |
+
n_jobs=config.umap_n_jobs
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# 3. [ํต์ฌ] Original ๋ฐ์ดํฐ๋ก๋ง ๊ณต๊ฐ ํ์ต (Fit)
|
| 284 |
+
# ์๋ณธ ๋ฐ์ดํฐ์ ๊ตฌ์กฐ(Manifold)๋ง ํ์ตํฉ๋๋ค.
|
| 285 |
+
original_embedding = umap_model.fit_transform(original_data)
|
| 286 |
+
|
| 287 |
+
# 4. [ํต์ฌ] ํ์ต๋ ๊ณต๊ฐ์ Synthetic ๋ฐ์ดํฐ ํฌ์ (Transform)
|
| 288 |
+
# ํฉ์ฑ ๋ฐ์ดํฐ๋ ๊ณต๊ฐ ํ์ฑ์ ๊ด์ฌํ์ง ์๊ณ , ์ด๋ฏธ ๋ง๋ค์ด์ง ๊ณต๊ฐ์ ์์น๋ง ์ฐพ์ต๋๋ค.
|
| 289 |
+
synthetic_embedding = umap_model.transform(synthetic_data)
|
| 290 |
+
|
| 291 |
+
# 5. ๊ฒฐ๊ณผ ํฉ์น๊ธฐ (์๊ฐํ๋ฅผ ์ํด)
|
| 292 |
+
umap_results = np.vstack([original_embedding, synthetic_embedding])
|
| 293 |
+
|
| 294 |
+
# ์์ ๋ณด์ฅ์ ์ํด ๋ผ๋ฒจ๋ ๋ค์ ์ ๋ฆฌ (Original์ด ์, Synthetic์ด ๋ค)
|
| 295 |
+
combined_labels = pd.concat([
|
| 296 |
+
labels[is_original],
|
| 297 |
+
labels[~is_original]
|
| 298 |
+
], ignore_index=True)
|
| 299 |
+
|
| 300 |
+
# ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์ดํฐํ๋ ์์ผ๋ก ๋ณํ
|
| 301 |
+
df_umap = pd.DataFrame(umap_results, columns=['UMAP1', 'UMAP2'])
|
| 302 |
+
df_umap['Label'] = combined_labels
|
| 303 |
+
|
| 304 |
+
# ์ง์ญ ๋ฐ fold ์ ๋ณด ๋ฌธ์์ด ์์ฑ (title์ ์ฌ์ฉ)
|
| 305 |
+
title_parts = ["UMAP: Original vs Synthetic"]
|
| 306 |
+
if region is not None:
|
| 307 |
+
if fold_idx is not None:
|
| 308 |
+
fold = [[2018, 2019], [2018, 2020], [2019, 2020]]
|
| 309 |
+
if 0 <= fold_idx < len(fold):
|
| 310 |
+
years = fold[fold_idx]
|
| 311 |
+
fold_display = fold_idx + 1 # fold๋ฅผ +1ํด์ ํ์
|
| 312 |
+
title_parts.append(f"Region: {region.upper()} | Fold {fold_display}: {years[0]}-{years[1]}")
|
| 313 |
+
else:
|
| 314 |
+
title_parts.append(f"Region: {region.upper()}")
|
| 315 |
+
else:
|
| 316 |
+
title_parts.append(f"Region: {region.upper()}")
|
| 317 |
+
elif fold_idx is not None:
|
| 318 |
+
fold = [[2018, 2019], [2018, 2020], [2019, 2020]]
|
| 319 |
+
if 0 <= fold_idx < len(fold):
|
| 320 |
+
years = fold[fold_idx]
|
| 321 |
+
fold_display = fold_idx + 1 # fold๋ฅผ +1ํด์ ํ์
|
| 322 |
+
title_parts.append(f"Fold {fold_display}: {years[0]}-{years[1]}")
|
| 323 |
+
title_str = " - ".join(title_parts)
|
| 324 |
+
|
| 325 |
+
# Figure ๋ฐ Axes ์ค์ (๋จ์ผ ํ๋กฏ)
|
| 326 |
+
if ax is None:
|
| 327 |
+
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
|
| 328 |
+
else:
|
| 329 |
+
fig = ax.figure if hasattr(ax, 'figure') else plt.gcf()
|
| 330 |
+
|
| 331 |
+
# ์ ์ฒด ๋ฐ์ดํฐ์ UMAP ๋ฒ์ ๊ณ์ฐ
|
| 332 |
+
x_min = df_umap['UMAP1'].min() - 1
|
| 333 |
+
x_max = df_umap['UMAP1'].max() + 1
|
| 334 |
+
y_min = df_umap['UMAP2'].min() - 1
|
| 335 |
+
y_max = df_umap['UMAP2'].max() + 1
|
| 336 |
+
|
| 337 |
+
# Synthetic ๋ฐ์ดํฐ ์๊ฐํ (๋นจ๊ฐ์, ๋จผ์ ๊ทธ๋ ค์ ๋ค์ ์์น)
|
| 338 |
+
sns.scatterplot(
|
| 339 |
+
data=df_umap.loc[df_umap['Label'] == 'Synthetic'],
|
| 340 |
+
x='UMAP1', y='UMAP2',
|
| 341 |
+
color='red',
|
| 342 |
+
alpha=config.alpha,
|
| 343 |
+
label='Synthetic',
|
| 344 |
+
ax=ax,
|
| 345 |
+
s=30
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
# Original ๋ฐ์ดํฐ ์๊ฐํ (ํ๋์, ๋์ค์ ๊ทธ๋ ค์ ์์ ์์นํ์ฌ ๋ ์ ๋ณด์ด๊ฒ)
|
| 349 |
+
sns.scatterplot(
|
| 350 |
+
data=df_umap.loc[df_umap['Label'] == 'Original'],
|
| 351 |
+
x='UMAP1', y='UMAP2',
|
| 352 |
+
color='blue',
|
| 353 |
+
alpha=config.alpha,
|
| 354 |
+
label='Original',
|
| 355 |
+
ax=ax,
|
| 356 |
+
s=30
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
ax.set_xlim(x_min, x_max)
|
| 360 |
+
ax.set_ylim(y_min, y_max)
|
| 361 |
+
ax.set_xlabel('UMAP1', fontsize=12)
|
| 362 |
+
ax.set_ylabel('UMAP2', fontsize=12)
|
| 363 |
+
ax.set_title(title_str, fontsize=14, fontweight='bold')
|
| 364 |
+
ax.legend(title='Label', loc='best')
|
| 365 |
+
ax.grid(True, alpha=0.3)
|
| 366 |
+
|
| 367 |
+
plt.tight_layout()
|
| 368 |
+
return fig
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def generate_synthetic_path(
    method: str,
    region: str,
    sample_size: Optional[int] = None,
    fold_idx: Optional[int] = None,
    base_dir: str = "../../data/data_oversampled"
) -> str:
    """Build the file path of a synthetic (oversampled) dataset.

    The 'ctgan' and 'smotenc_ctgan' branches previously duplicated the
    same validation; they are consolidated here with identical behavior
    and messages.

    Args:
        method: augmentation method ('ctgan', 'smotenc_ctgan', 'smote').
        sample_size: sample count, required for 'ctgan'/'smotenc_ctgan'
            (one of 7000, 10000, 20000); ignored for 'smote'.
        region: region name.
        fold_idx: fold index (0, 1 or 2); None defaults to 0.
        base_dir: base directory of the oversampled data.

    Returns:
        Path to the synthetic-data CSV.

    Raises:
        ValueError: on an unknown method, or a missing/invalid
            sample_size for the size-dependent methods.
    """
    # Default to fold 0 (i.e. fold 1 in file names).
    if fold_idx is None:
        fold_idx = 0
    fold_num = fold_idx + 1  # file names are 1-based, fold_idx is 0-based

    if method in ('ctgan', 'smotenc_ctgan'):
        # Both size-dependent methods share the same validation and
        # file-name pattern (only the method prefix differs).
        if sample_size is None:
            raise ValueError(f"{method} ๋ฐฉ๋ฒ์ sample_size๊ฐ ํ์ํฉ๋๋ค (7000, 10000, 20000 ์ค ์ ํ)")
        if sample_size not in (7000, 10000, 20000):
            raise ValueError(f"sample_size๋ 7000, 10000, 20000 ์ค ํ๋์ฌ์ผ ํฉ๋๋ค. ์๋ ฅ๊ฐ: {sample_size}")
        return f"{base_dir}/augmented_only/{method}{sample_size}_{fold_num}_{region}.csv"

    if method == 'smote':
        # smote takes no sample_size; assumed to live in augmented_only
        # as well (fold number included in the file name).
        return f"{base_dir}/augmented_only/smote_{fold_num}_{region}.csv"

    raise ValueError(f"์ง์ํ์ง ์๋ method์๋๋ค: {method}. 'ctgan', 'smotenc_ctgan', 'smote' ์ค ํ๋๋ฅผ ์ ํํ์ธ์.")
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def main(
    method: str = "ctgan",
    sample_size: Optional[int] = 7000,
    config: Optional[PlotConfig] = None,
    region: Optional[str] = "busan",
    fold_idx: Optional[int] = 0,
    data_dir: str = "../../data/data_for_modeling",
    original_path: Optional[str] = None,
    synthetic_path: Optional[str] = None,
    base_dir: str = "../../data/data_oversampled"
) -> None:
    """Run the full pipeline once: load, scale, embed with UMAP, show.

    Args:
        method: augmentation method ('ctgan', 'smotenc_ctgan', 'smote').
        sample_size: sample count for ctgan/smotenc_ctgan (7000, 10000
            or 20000); ignored for smote.
        config: PlotConfig; defaults are used when None.
        region: region name; used for the original data when
            original_path is None, and for the synthetic path when
            synthetic_path is None.
        fold_idx: fold index (0, 1 or 2); None keeps all years.
        data_dir: original-data directory (region mode).
        original_path: explicit original-data file (overrides region/fold).
        synthetic_path: explicit synthetic-data file (overrides
            method/sample_size).
        base_dir: base directory for generated synthetic paths.

    Raises:
        ValueError: when neither synthetic_path nor region is available.
    """
    cfg = PlotConfig() if config is None else config

    # Resolve the synthetic-data file: an explicit path wins, otherwise it
    # is derived from method/sample_size/region/fold.
    resolved_synthetic = synthetic_path
    if resolved_synthetic is None:
        if region is None:
            raise ValueError("synthetic_path๋ฅผ ์ง์ ํ์ง ์์ผ๋ฉด region์ด ํ์ํฉ๋๋ค.")
        resolved_synthetic = generate_synthetic_path(method, region, sample_size, fold_idx, base_dir)

    # Load and preprocess both datasets.
    frame_original, frame_synthetic = load_and_preprocess_data(
        synthetic_path=resolved_synthetic,
        config=cfg,
        region=region,
        fold_idx=fold_idx,
        data_dir=data_dir,
        original_path=original_path
    )

    # Scale features for the embedding.
    matrix, labels, _scaler = prepare_features_for_visualization(
        frame_original, frame_synthetic, cfg
    )

    # Embed, plot and display.
    plot_umap_comparison(
        matrix,
        labels,
        cfg,
        region=region,
        fold_idx=fold_idx
    )
    plt.show()
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def generate_all_plots(
    output_dir: str = "images",
    config: Optional[PlotConfig] = None,
    data_dir: str = "../../data/data_for_modeling",
    base_dir: str = "../../data/data_oversampled"
) -> None:
    """Generate and save every plot combination for the paper.

    Combinations:
        - regions: incheon, seoul, busan, daegu, daejeon, gwangju (6)
        - folds: 0, 1, 2 (3)
        - methods: ctgan / smotenc_ctgan at sizes 7000, 10000, 20000,
          plus smote (no size)

    Total: 6 regions x 3 folds x (6 sized methods + 1 smote) = 126 plots.

    The two previously duplicated loops (sized methods vs smote) are
    merged into one; iteration order and progress numbering are
    unchanged. The save settings, which disagreed between the loops
    (dpi 600 + transparent background vs dpi 300 + opaque), are unified
    to dpi 300 / opaque white, and the success message now reports the
    saved file path.

    Args:
        output_dir: directory the PNG files are written to.
        config: PlotConfig; defaults are used when None.
        data_dir: original-data directory.
        base_dir: base directory of the synthetic data.
    """
    if config is None:
        config = PlotConfig()

    # Ensure the output directory exists.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # All combinations; (method, sample_size) with None size for smote.
    regions = ['incheon', 'seoul', 'busan', 'daegu', 'daejeon', 'gwangju']
    fold_indices = [0, 1, 2]
    fold_years = [[2018, 2019], [2018, 2020], [2019, 2020]]
    method_combos = [
        ('ctgan', 7000),
        ('ctgan', 10000),
        ('ctgan', 20000),
        ('smotenc_ctgan', 7000),
        ('smotenc_ctgan', 10000),
        ('smotenc_ctgan', 20000),
        ('smote', None),
    ]

    total_plots = len(regions) * len(fold_indices) * len(method_combos)
    current_plot = 0

    print(f"์ด {total_plots}๊ฐ์ plot์ ์์ฑํฉ๋๋ค...")
    print("=" * 60)

    for method, sample_size in method_combos:
        for region in regions:
            for fold_idx in fold_indices:
                current_plot += 1
                try:
                    # Progress message mirrors the original: sized methods
                    # include "(size=...)", smote does not.
                    size_tag = f" (size={sample_size})" if sample_size is not None else ""
                    print(f"[{current_plot}/{total_plots}] {method}{size_tag} - {region.upper()} - Fold {fold_idx + 1} ์์ฑ ์ค...")

                    # Locate the synthetic data for this combination.
                    synthetic_path = generate_synthetic_path(method, region, sample_size, fold_idx, base_dir)

                    # Load and preprocess both datasets.
                    original_data, synthetic_data = load_and_preprocess_data(
                        synthetic_path=synthetic_path,
                        config=config,
                        region=region,
                        fold_idx=fold_idx,
                        data_dir=data_dir,
                        original_path=None
                    )

                    # Scale features for the embedding.
                    scaled_features, labels, _scaler = prepare_features_for_visualization(
                        original_data, synthetic_data, config
                    )

                    # Embed and plot.
                    fig = plot_umap_comparison(
                        scaled_features,
                        labels,
                        config,
                        region=region,
                        fold_idx=fold_idx
                    )

                    # File name: method[_size]_region_foldN_yearA-yearB.png
                    years = fold_years[fold_idx]
                    size_part = f"{sample_size}_" if sample_size is not None else ""
                    filename = f"{method}_{size_part}{region}_fold{fold_idx + 1}_{years[0]}-{years[1]}.png"
                    filepath = output_path / filename

                    # Publication-quality save settings (unified for all methods).
                    fig.savefig(
                        filepath,
                        dpi=300,                # 300 dpi satisfies most journals
                        bbox_inches='tight',    # trim automatic margins
                        pad_inches=0.1,         # small margin for readability
                        facecolor='white',      # opaque white background
                        edgecolor='none',       # no border
                        format='png',           # switch to 'pdf' if needed
                        transparent=False       # keep the white background
                    )
                    plt.close(fig)

                    print(f" โ ์ ์ฅ ์๋ฃ: {filepath}")

                except Exception as e:
                    # Best-effort batch: report and move on to the next combo.
                    print(f" โ ์ค๋ฅ ๋ฐ์: {str(e)}")
                    continue

    print("=" * 60)
    print(f"๋ชจ๋ plot ์์ฑ ์๋ฃ! ์ด {current_plot}๊ฐ ํ์ผ์ด {output_dir}์ ์ ์ฅ๋์์ต๋๋ค.")
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
if __name__ == "__main__":
    # Generate a single plot (default usage):
    # main()

    # Generate every region/fold/method combination (for the paper).
    generate_all_plots(output_dir="images")
|
Analysis_code/3.sampled_data_analysis/oversampling_model_hyperparameter.ipynb
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"id": "829c34fa",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"\"\"\"\n",
|
| 11 |
+
"CTGAN ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ์ถ ๋ฐ ์ ๋ฆฌ\n",
|
| 12 |
+
"๋
ผ๋ฌธ ์์ฑ์ฉ์ผ๋ก ๋ชจ๋ ์ ์ฅ๋ ๋ชจ๋ธ์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ถ์ถํฉ๋๋ค.\n",
|
| 13 |
+
"\"\"\"\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"import pandas as pd\n",
|
| 16 |
+
"import numpy as np\n",
|
| 17 |
+
"from pathlib import Path\n",
|
| 18 |
+
"from ctgan import CTGAN\n",
|
| 19 |
+
"import re\n",
|
| 20 |
+
"from typing import Dict, Any\n",
|
| 21 |
+
"import warnings\n",
|
| 22 |
+
"warnings.filterwarnings('ignore')\n"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": 3,
|
| 28 |
+
"id": "98679ba3",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"outputs": [
|
| 31 |
+
{
|
| 32 |
+
"name": "stdout",
|
| 33 |
+
"output_type": "stream",
|
| 34 |
+
"text": [
|
| 35 |
+
"์ด 216๊ฐ์ ๋ชจ๋ธ ํ์ผ์ ์ฐพ์์ต๋๋ค.\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"์ฒ์ 5๊ฐ ํ์ผ ์์:\n",
|
| 38 |
+
" - ctgan_only_10000_1_busan_class0.pkl\n",
|
| 39 |
+
" - ctgan_only_10000_1_busan_class1.pkl\n",
|
| 40 |
+
" - ctgan_only_10000_1_daegu_class0.pkl\n",
|
| 41 |
+
" - ctgan_only_10000_1_daegu_class1.pkl\n",
|
| 42 |
+
" - ctgan_only_10000_1_daejeon_class0.pkl\n"
|
| 43 |
+
]
|
| 44 |
+
}
|
| 45 |
+
],
|
| 46 |
+
"source": [
|
| 47 |
+
"# ๋ชจ๋ธ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์ \n",
|
| 48 |
+
"model_dir = Path(\"../save_model/oversampling_models\")\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"# ๋ชจ๋ธ ํ์ผ ๋ชฉ๋ก ํ์ธ\n",
|
| 51 |
+
"model_files = sorted(list(model_dir.glob(\"*.pkl\")))\n",
|
| 52 |
+
"print(f\"์ด {len(model_files)}๊ฐ์ ๋ชจ๋ธ ํ์ผ์ ์ฐพ์์ต๋๋ค.\")\n",
|
| 53 |
+
"print(f\"\\n์ฒ์ 5๊ฐ ํ์ผ ์์:\")\n",
|
| 54 |
+
"for f in model_files[:5]:\n",
|
| 55 |
+
" print(f\" - {f.name}\")\n"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": 4,
|
| 61 |
+
"id": "97cde9e3",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [
|
| 64 |
+
{
|
| 65 |
+
"name": "stdout",
|
| 66 |
+
"output_type": "stream",
|
| 67 |
+
"text": [
|
| 68 |
+
"CTGAN ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ:\n",
|
| 69 |
+
" embedding_dim: 64\n",
|
| 70 |
+
" generator_dim: (64, 64)\n",
|
| 71 |
+
" discriminator_dim: (128, 128)\n",
|
| 72 |
+
" batch_size: 256\n",
|
| 73 |
+
" epochs: 300\n",
|
| 74 |
+
" pac: 8\n",
|
| 75 |
+
" discriminator_steps: 2\n",
|
| 76 |
+
" generator_lr: 0.0002\n",
|
| 77 |
+
" discriminator_lr: 0.0002\n",
|
| 78 |
+
" generator_decay: 1e-06\n",
|
| 79 |
+
" discriminator_decay: 1e-06\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"๋์
๋๋ฆฌ ํํ:\n",
|
| 82 |
+
"{'embedding_dim': 64, 'generator_dim': (64, 64), 'discriminator_dim': (128, 128), 'batch_size': 256, 'epochs': 300, 'pac': 8, 'discriminator_steps': 2, 'generator_lr': 0.0002, 'discriminator_lr': 0.0002, 'generator_decay': 1e-06, 'discriminator_decay': 1e-06}\n"
|
| 83 |
+
]
|
| 84 |
+
}
|
| 85 |
+
],
|
| 86 |
+
"source": [
|
| 87 |
+
"# CTGAN ๋ชจ๋ธ ๋ก๋ ๋ฐ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ธ ์์ \n",
|
| 88 |
+
"model = CTGAN.load(\"../save_model/oversampling_models/ctgan_only_10000_1_busan_class0.pkl\")\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"# CTGAN ๋ชจ๋ธ์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ ๋ด๋ถ ์์ฑ(_๋ก ์์)์ ์ ์ฅ๋์ด ์์ต๋๋ค\n",
|
| 91 |
+
"print(\"CTGAN ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ:\")\n",
|
| 92 |
+
"print(f\" embedding_dim: {model._embedding_dim}\")\n",
|
| 93 |
+
"print(f\" generator_dim: {model._generator_dim}\")\n",
|
| 94 |
+
"print(f\" discriminator_dim: {model._discriminator_dim}\")\n",
|
| 95 |
+
"print(f\" batch_size: {model._batch_size}\")\n",
|
| 96 |
+
"print(f\" epochs: {model._epochs}\")\n",
|
| 97 |
+
"print(f\" pac: {model.pac}\") # pac๋ ๊ณต๊ฐ ์์ฑ์ผ๋ก๋ ์ ๊ทผ ๊ฐ๋ฅ\n",
|
| 98 |
+
"print(f\" discriminator_steps: {model._discriminator_steps}\")\n",
|
| 99 |
+
"print(f\" generator_lr: {model._generator_lr}\")\n",
|
| 100 |
+
"print(f\" discriminator_lr: {model._discriminator_lr}\")\n",
|
| 101 |
+
"print(f\" generator_decay: {model._generator_decay}\")\n",
|
| 102 |
+
"print(f\" discriminator_decay: {model._discriminator_decay}\")\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"# ๋ชจ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ๋์
๋๋ฆฌ๋ก ์ถ์ถํ๋ ๋ฐฉ๋ฒ\n",
|
| 105 |
+
"hyperparams = {\n",
|
| 106 |
+
" 'embedding_dim': model._embedding_dim,\n",
|
| 107 |
+
" 'generator_dim': model._generator_dim,\n",
|
| 108 |
+
" 'discriminator_dim': model._discriminator_dim,\n",
|
| 109 |
+
" 'batch_size': model._batch_size,\n",
|
| 110 |
+
" 'epochs': model._epochs,\n",
|
| 111 |
+
" 'pac': model.pac,\n",
|
| 112 |
+
" 'discriminator_steps': model._discriminator_steps,\n",
|
| 113 |
+
" 'generator_lr': model._generator_lr,\n",
|
| 114 |
+
" 'discriminator_lr': model._discriminator_lr,\n",
|
| 115 |
+
" 'generator_decay': model._generator_decay,\n",
|
| 116 |
+
" 'discriminator_decay': model._discriminator_decay,\n",
|
| 117 |
+
"}\n",
|
| 118 |
+
"print(\"\\n๋์
๋๋ฆฌ ํํ:\")\n",
|
| 119 |
+
"print(hyperparams)"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"cell_type": "code",
|
| 124 |
+
"execution_count": 5,
|
| 125 |
+
"id": "e3631f3b",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [
|
| 128 |
+
{
|
| 129 |
+
"name": "stdout",
|
| 130 |
+
"output_type": "stream",
|
| 131 |
+
"text": [
|
| 132 |
+
"ํ
์คํธ ํ์ผ: ctgan_only_10000_1_busan_class0.pkl\n",
|
| 133 |
+
"ํ์ฑ ๊ฒฐ๊ณผ: {'method': 'ctgan', 'sample_size': 10000, 'fold': 1, 'region': 'busan', 'class': 0}\n",
|
| 134 |
+
"ํ์ดํผํ๋ผ๋ฏธํฐ: {'embedding_dim': 64, 'generator_dim': '(64, 64)', 'discriminator_dim': '(128, 128)', 'pac': 8, 'batch_size': 256, 'discriminator_steps': 2, 'epochs': 300, 'generator_lr': 0.0002, 'discriminator_lr': 0.0002, 'generator_decay': 1e-06, 'discriminator_decay': 1e-06}\n"
|
| 135 |
+
]
|
| 136 |
+
}
|
| 137 |
+
],
|
| 138 |
+
"source": [
|
| 139 |
+
"def parse_model_filename(filename: str) -> Dict[str, Any]:\n",
|
| 140 |
+
" \"\"\"\n",
|
| 141 |
+
" ๋ชจ๋ธ ํ์ผ๋ช
์์ ์ ๋ณด๋ฅผ ํ์ฑํฉ๋๋ค.\n",
|
| 142 |
+
" \n",
|
| 143 |
+
" ํ์ผ๋ช
ํจํด:\n",
|
| 144 |
+
" - ctgan_only_{sample_size}_{fold}_{region}_class{0|1}.pkl\n",
|
| 145 |
+
" - smotenc_ctgan_{sample_size}_{fold}_{region}_class{0|1}.pkl\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" Returns:\n",
|
| 148 |
+
" ํ์ฑ๋ ์ ๋ณด ๋์
๋๋ฆฌ\n",
|
| 149 |
+
" \"\"\"\n",
|
| 150 |
+
" # ํ์ผ๋ช
์์ ํ์ฅ์ ์ ๊ฑฐ\n",
|
| 151 |
+
" name = filename.replace('.pkl', '')\n",
|
| 152 |
+
" \n",
|
| 153 |
+
" # ํจํด ๋งค์นญ\n",
|
| 154 |
+
" if name.startswith('ctgan_only_'):\n",
|
| 155 |
+
" method = 'ctgan'\n",
|
| 156 |
+
" parts = name.replace('ctgan_only_', '').split('_')\n",
|
| 157 |
+
" elif name.startswith('smotenc_ctgan_'):\n",
|
| 158 |
+
" method = 'smotenc_ctgan'\n",
|
| 159 |
+
" parts = name.replace('smotenc_ctgan_', '').split('_')\n",
|
| 160 |
+
" else:\n",
|
| 161 |
+
" return None\n",
|
| 162 |
+
" \n",
|
| 163 |
+
" # sample_size, fold, region, class ์ถ์ถ\n",
|
| 164 |
+
" sample_size = int(parts[0])\n",
|
| 165 |
+
" fold = int(parts[1])\n",
|
| 166 |
+
" region = parts[2]\n",
|
| 167 |
+
" class_label = int(parts[3].replace('class', ''))\n",
|
| 168 |
+
" \n",
|
| 169 |
+
" return {\n",
|
| 170 |
+
" 'method': method,\n",
|
| 171 |
+
" 'sample_size': sample_size,\n",
|
| 172 |
+
" 'fold': fold,\n",
|
| 173 |
+
" 'region': region,\n",
|
| 174 |
+
" 'class': class_label\n",
|
| 175 |
+
" }\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"def extract_hyperparameters(model_path: Path) -> Dict[str, Any]:\n",
|
| 179 |
+
" \"\"\"\n",
|
| 180 |
+
" CTGAN ๋ชจ๋ธ์์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ถ์ถํฉ๋๋ค.\n",
|
| 181 |
+
" \n",
|
| 182 |
+
" CTGAN ๋ชจ๋ธ์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ ๋ด๋ถ ์์ฑ(_๋ก ์์)์ ์ ์ฅ๋์ด ์์ต๋๋ค:\n",
|
| 183 |
+
" - _embedding_dim: ์๋ฒ ๋ฉ ์ฐจ์\n",
|
| 184 |
+
" - _generator_dim: ์์ฑ๊ธฐ ๋คํธ์ํฌ ์ฐจ์ (ํํ)\n",
|
| 185 |
+
" - _discriminator_dim: ํ๋ณ๊ธฐ ๋คํธ์ํฌ ์ฐจ์ (ํํ)\n",
|
| 186 |
+
" - _batch_size: ๋ฐฐ์น ํฌ๊ธฐ\n",
|
| 187 |
+
" - _epochs: ์ํฌํฌ ์\n",
|
| 188 |
+
" - _pac: PAC ํ๋ผ๋ฏธํฐ (๋๋ pac ์์ฑ์ผ๋ก ์ ๊ทผ ๊ฐ๋ฅ)\n",
|
| 189 |
+
" - _generator_lr: ์์ฑ๊ธฐ ํ์ต๋ฅ \n",
|
| 190 |
+
" - _discriminator_lr: ํ๋ณ๊ธฐ ํ์ต๋ฅ \n",
|
| 191 |
+
" - _discriminator_steps: ํ๋ณ๊ธฐ ์
๋ฐ์ดํธ ์คํ
์\n",
|
| 192 |
+
" \n",
|
| 193 |
+
" Args:\n",
|
| 194 |
+
" model_path: ๋ชจ๋ธ ํ์ผ ๊ฒฝ๋ก\n",
|
| 195 |
+
" \n",
|
| 196 |
+
" Returns:\n",
|
| 197 |
+
" ํ์ดํผํ๋ผ๋ฏธํฐ ๋์
๋๋ฆฌ\n",
|
| 198 |
+
" \"\"\"\n",
|
| 199 |
+
" try:\n",
|
| 200 |
+
" # ๋ชจ๋ธ ๋ก๋\n",
|
| 201 |
+
" model = CTGAN.load(str(model_path))\n",
|
| 202 |
+
" \n",
|
| 203 |
+
" # ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ์ถ (๋ด๋ถ ์์ฑ ์ฌ์ฉ)\n",
|
| 204 |
+
" hyperparams = {\n",
|
| 205 |
+
" 'embedding_dim': getattr(model, '_embedding_dim', None),\n",
|
| 206 |
+
" 'generator_dim': str(getattr(model, '_generator_dim', None)), # ํํ์ ๋ฌธ์์ด๋ก ๋ณํ\n",
|
| 207 |
+
" 'discriminator_dim': str(getattr(model, '_discriminator_dim', None)), # ํํ์ ๋ฌธ์์ด๋ก ๋ณํ\n",
|
| 208 |
+
" 'pac': getattr(model, 'pac', None) or getattr(model, '_pac', None), # pac ์์ฑ ๋๋ _pac ์์ฑ\n",
|
| 209 |
+
" 'batch_size': getattr(model, '_batch_size', None),\n",
|
| 210 |
+
" 'discriminator_steps': getattr(model, '_discriminator_steps', None),\n",
|
| 211 |
+
" 'epochs': getattr(model, '_epochs', None),\n",
|
| 212 |
+
" 'generator_lr': getattr(model, '_generator_lr', None),\n",
|
| 213 |
+
" 'discriminator_lr': getattr(model, '_discriminator_lr', None),\n",
|
| 214 |
+
" 'generator_decay': getattr(model, '_generator_decay', None),\n",
|
| 215 |
+
" 'discriminator_decay': getattr(model, '_discriminator_decay', None),\n",
|
| 216 |
+
" }\n",
|
| 217 |
+
" \n",
|
| 218 |
+
" return hyperparams\n",
|
| 219 |
+
" except Exception as e:\n",
|
| 220 |
+
" print(f\"Error loading {model_path.name}: {str(e)}\")\n",
|
| 221 |
+
" import traceback\n",
|
| 222 |
+
" print(traceback.format_exc())\n",
|
| 223 |
+
" return None\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"# ํ
์คํธ: ์ฒซ ๋ฒ์งธ ๋ชจ๋ธ ํ์ผ๋ก ํ
์คํธ\n",
|
| 227 |
+
"if len(model_files) > 0:\n",
|
| 228 |
+
" test_file = model_files[0]\n",
|
| 229 |
+
" print(f\"ํ
์คํธ ํ์ผ: {test_file.name}\")\n",
|
| 230 |
+
" parsed = parse_model_filename(test_file.name)\n",
|
| 231 |
+
" print(f\"ํ์ฑ ๊ฒฐ๊ณผ: {parsed}\")\n",
|
| 232 |
+
" hyperparams = extract_hyperparameters(test_file)\n",
|
| 233 |
+
" print(f\"ํ์ดํผํ๋ผ๋ฏธํฐ: {hyperparams}\")\n"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"cell_type": "code",
|
| 238 |
+
"execution_count": 6,
|
| 239 |
+
"id": "9fc03ebe",
|
| 240 |
+
"metadata": {},
|
| 241 |
+
"outputs": [
|
| 242 |
+
{
|
| 243 |
+
"name": "stdout",
|
| 244 |
+
"output_type": "stream",
|
| 245 |
+
"text": [
|
| 246 |
+
"๋ชจ๋ ๋ชจ๋ธ ํ์ผ์์ ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ์ถ ์ค...\n",
|
| 247 |
+
"================================================================================\n",
|
| 248 |
+
"[20/216] ์งํ ์ค... (20๊ฐ ์ฑ๊ณต)\n",
|
| 249 |
+
"[40/216] ์งํ ์ค... (40๊ฐ ์ฑ๊ณต)\n",
|
| 250 |
+
"[60/216] ์งํ ์ค... (60๊ฐ ์ฑ๊ณต)\n",
|
| 251 |
+
"[80/216] ์งํ ์ค... (80๊ฐ ์ฑ๊ณต)\n",
|
| 252 |
+
"[100/216] ์งํ ์ค... (100๊ฐ ์ฑ๊ณต)\n",
|
| 253 |
+
"[120/216] ์งํ ์ค... (120๊ฐ ์ฑ๊ณต)\n",
|
| 254 |
+
"[140/216] ์งํ ์ค... (140๊ฐ ์ฑ๊ณต)\n",
|
| 255 |
+
"[160/216] ์งํ ์ค... (160๊ฐ ์ฑ๊ณต)\n",
|
| 256 |
+
"[180/216] ์งํ ์ค... (180๊ฐ ์ฑ๊ณต)\n",
|
| 257 |
+
"[200/216] ์งํ ์ค... (200๊ฐ ์ฑ๊ณต)\n",
|
| 258 |
+
"================================================================================\n",
|
| 259 |
+
"์๋ฃ! ์ด 216๊ฐ์ ๋ชจ๋ธ์์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ถ์ถํ์ต๋๋ค.\n"
|
| 260 |
+
]
|
| 261 |
+
}
|
| 262 |
+
],
|
| 263 |
+
"source": [
|
| 264 |
+
"# ๋ชจ๋ ๋ชจ๋ธ ํ์ผ์์ ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ์ถ\n",
|
| 265 |
+
"all_results = []\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"print(\"๋ชจ๋ ๋ชจ๋ธ ํ์ผ์์ ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ์ถ ์ค...\")\n",
|
| 268 |
+
"print(\"=\" * 80)\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"for i, model_file in enumerate(model_files, 1):\n",
|
| 271 |
+
" # ํ์ผ๋ช
ํ์ฑ\n",
|
| 272 |
+
" parsed_info = parse_model_filename(model_file.name)\n",
|
| 273 |
+
" if parsed_info is None:\n",
|
| 274 |
+
" print(f\"[{i}/{len(model_files)}] ์คํต: {model_file.name} (ํ์ผ๋ช
ํจํด ๋ถ์ผ์น)\")\n",
|
| 275 |
+
" continue\n",
|
| 276 |
+
" \n",
|
| 277 |
+
" # ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ์ถ\n",
|
| 278 |
+
" hyperparams = extract_hyperparameters(model_file)\n",
|
| 279 |
+
" if hyperparams is None:\n",
|
| 280 |
+
" print(f\"[{i}/{len(model_files)}] ์คํจ: {model_file.name}\")\n",
|
| 281 |
+
" continue\n",
|
| 282 |
+
" \n",
|
| 283 |
+
" # ์ ๋ณด ํฉ์น๊ธฐ\n",
|
| 284 |
+
" result = {**parsed_info, **hyperparams}\n",
|
| 285 |
+
" result['filename'] = model_file.name\n",
|
| 286 |
+
" all_results.append(result)\n",
|
| 287 |
+
" \n",
|
| 288 |
+
" if i % 20 == 0:\n",
|
| 289 |
+
" print(f\"[{i}/{len(model_files)}] ์งํ ์ค... ({len(all_results)}๊ฐ ์ฑ๊ณต)\")\n",
|
| 290 |
+
"\n",
|
| 291 |
+
"print(\"=\" * 80)\n",
|
| 292 |
+
"print(f\"์๋ฃ! ์ด {len(all_results)}๊ฐ์ ๋ชจ๋ธ์์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ถ์ถํ์ต๋๋ค.\")\n"
|
| 293 |
+
]
|
| 294 |
+
},
|
| 295 |
+
{
|
| 296 |
+
"cell_type": "code",
|
| 297 |
+
"execution_count": 7,
|
| 298 |
+
"id": "223e2b49",
|
| 299 |
+
"metadata": {},
|
| 300 |
+
"outputs": [
|
| 301 |
+
{
|
| 302 |
+
"name": "stdout",
|
| 303 |
+
"output_type": "stream",
|
| 304 |
+
"text": [
|
| 305 |
+
"์ด 216๊ฐ์ ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ๊ฐ ์ ๋ฆฌ๋์์ต๋๋ค.\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"์ปฌ๋ผ: ['method', 'sample_size', 'fold', 'region', 'class', 'embedding_dim', 'generator_dim', 'discriminator_dim', 'pac', 'batch_size', 'discriminator_steps', 'epochs', 'generator_lr', 'discriminator_lr', 'filename']\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"์ฒ์ 5๊ฐ ํ:\n"
|
| 310 |
+
]
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"data": {
|
| 314 |
+
"text/html": [
|
| 315 |
+
"<div>\n",
|
| 316 |
+
"<style scoped>\n",
|
| 317 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 318 |
+
" vertical-align: middle;\n",
|
| 319 |
+
" }\n",
|
| 320 |
+
"\n",
|
| 321 |
+
" .dataframe tbody tr th {\n",
|
| 322 |
+
" vertical-align: top;\n",
|
| 323 |
+
" }\n",
|
| 324 |
+
"\n",
|
| 325 |
+
" .dataframe thead th {\n",
|
| 326 |
+
" text-align: right;\n",
|
| 327 |
+
" }\n",
|
| 328 |
+
"</style>\n",
|
| 329 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 330 |
+
" <thead>\n",
|
| 331 |
+
" <tr style=\"text-align: right;\">\n",
|
| 332 |
+
" <th></th>\n",
|
| 333 |
+
" <th>method</th>\n",
|
| 334 |
+
" <th>sample_size</th>\n",
|
| 335 |
+
" <th>fold</th>\n",
|
| 336 |
+
" <th>region</th>\n",
|
| 337 |
+
" <th>class</th>\n",
|
| 338 |
+
" <th>embedding_dim</th>\n",
|
| 339 |
+
" <th>generator_dim</th>\n",
|
| 340 |
+
" <th>discriminator_dim</th>\n",
|
| 341 |
+
" <th>pac</th>\n",
|
| 342 |
+
" <th>batch_size</th>\n",
|
| 343 |
+
" <th>discriminator_steps</th>\n",
|
| 344 |
+
" <th>epochs</th>\n",
|
| 345 |
+
" <th>generator_lr</th>\n",
|
| 346 |
+
" <th>discriminator_lr</th>\n",
|
| 347 |
+
" <th>filename</th>\n",
|
| 348 |
+
" </tr>\n",
|
| 349 |
+
" </thead>\n",
|
| 350 |
+
" <tbody>\n",
|
| 351 |
+
" <tr>\n",
|
| 352 |
+
" <th>0</th>\n",
|
| 353 |
+
" <td>ctgan</td>\n",
|
| 354 |
+
" <td>7000</td>\n",
|
| 355 |
+
" <td>1</td>\n",
|
| 356 |
+
" <td>busan</td>\n",
|
| 357 |
+
" <td>0</td>\n",
|
| 358 |
+
" <td>78</td>\n",
|
| 359 |
+
" <td>(128, 128)</td>\n",
|
| 360 |
+
" <td>(128, 128)</td>\n",
|
| 361 |
+
" <td>8</td>\n",
|
| 362 |
+
" <td>256</td>\n",
|
| 363 |
+
" <td>3</td>\n",
|
| 364 |
+
" <td>300</td>\n",
|
| 365 |
+
" <td>0.0002</td>\n",
|
| 366 |
+
" <td>0.0002</td>\n",
|
| 367 |
+
" <td>ctgan_only_7000_1_busan_class0.pkl</td>\n",
|
| 368 |
+
" </tr>\n",
|
| 369 |
+
" <tr>\n",
|
| 370 |
+
" <th>1</th>\n",
|
| 371 |
+
" <td>ctgan</td>\n",
|
| 372 |
+
" <td>7000</td>\n",
|
| 373 |
+
" <td>1</td>\n",
|
| 374 |
+
" <td>busan</td>\n",
|
| 375 |
+
" <td>1</td>\n",
|
| 376 |
+
" <td>269</td>\n",
|
| 377 |
+
" <td>(256, 256)</td>\n",
|
| 378 |
+
" <td>(128, 128)</td>\n",
|
| 379 |
+
" <td>4</td>\n",
|
| 380 |
+
" <td>1024</td>\n",
|
| 381 |
+
" <td>1</td>\n",
|
| 382 |
+
" <td>300</td>\n",
|
| 383 |
+
" <td>0.0002</td>\n",
|
| 384 |
+
" <td>0.0002</td>\n",
|
| 385 |
+
" <td>ctgan_only_7000_1_busan_class1.pkl</td>\n",
|
| 386 |
+
" </tr>\n",
|
| 387 |
+
" <tr>\n",
|
| 388 |
+
" <th>2</th>\n",
|
| 389 |
+
" <td>ctgan</td>\n",
|
| 390 |
+
" <td>7000</td>\n",
|
| 391 |
+
" <td>1</td>\n",
|
| 392 |
+
" <td>daegu</td>\n",
|
| 393 |
+
" <td>0</td>\n",
|
| 394 |
+
" <td>121</td>\n",
|
| 395 |
+
" <td>(128, 128)</td>\n",
|
| 396 |
+
" <td>(64, 64)</td>\n",
|
| 397 |
+
" <td>4</td>\n",
|
| 398 |
+
" <td>64</td>\n",
|
| 399 |
+
" <td>2</td>\n",
|
| 400 |
+
" <td>300</td>\n",
|
| 401 |
+
" <td>0.0002</td>\n",
|
| 402 |
+
" <td>0.0002</td>\n",
|
| 403 |
+
" <td>ctgan_only_7000_1_daegu_class0.pkl</td>\n",
|
| 404 |
+
" </tr>\n",
|
| 405 |
+
" <tr>\n",
|
| 406 |
+
" <th>3</th>\n",
|
| 407 |
+
" <td>ctgan</td>\n",
|
| 408 |
+
" <td>7000</td>\n",
|
| 409 |
+
" <td>1</td>\n",
|
| 410 |
+
" <td>daegu</td>\n",
|
| 411 |
+
" <td>1</td>\n",
|
| 412 |
+
" <td>217</td>\n",
|
| 413 |
+
" <td>(128, 128)</td>\n",
|
| 414 |
+
" <td>(128, 128)</td>\n",
|
| 415 |
+
" <td>4</td>\n",
|
| 416 |
+
" <td>256</td>\n",
|
| 417 |
+
" <td>5</td>\n",
|
| 418 |
+
" <td>300</td>\n",
|
| 419 |
+
" <td>0.0002</td>\n",
|
| 420 |
+
" <td>0.0002</td>\n",
|
| 421 |
+
" <td>ctgan_only_7000_1_daegu_class1.pkl</td>\n",
|
| 422 |
+
" </tr>\n",
|
| 423 |
+
" <tr>\n",
|
| 424 |
+
" <th>4</th>\n",
|
| 425 |
+
" <td>ctgan</td>\n",
|
| 426 |
+
" <td>7000</td>\n",
|
| 427 |
+
" <td>1</td>\n",
|
| 428 |
+
" <td>daejeon</td>\n",
|
| 429 |
+
" <td>0</td>\n",
|
| 430 |
+
" <td>101</td>\n",
|
| 431 |
+
" <td>(128, 128)</td>\n",
|
| 432 |
+
" <td>(128, 128)</td>\n",
|
| 433 |
+
" <td>4</td>\n",
|
| 434 |
+
" <td>128</td>\n",
|
| 435 |
+
" <td>2</td>\n",
|
| 436 |
+
" <td>300</td>\n",
|
| 437 |
+
" <td>0.0002</td>\n",
|
| 438 |
+
" <td>0.0002</td>\n",
|
| 439 |
+
" <td>ctgan_only_7000_1_daejeon_class0.pkl</td>\n",
|
| 440 |
+
" </tr>\n",
|
| 441 |
+
" </tbody>\n",
|
| 442 |
+
"</table>\n",
|
| 443 |
+
"</div>"
|
| 444 |
+
],
|
| 445 |
+
"text/plain": [
|
| 446 |
+
" method sample_size fold region class embedding_dim generator_dim \\\n",
|
| 447 |
+
"0 ctgan 7000 1 busan 0 78 (128, 128) \n",
|
| 448 |
+
"1 ctgan 7000 1 busan 1 269 (256, 256) \n",
|
| 449 |
+
"2 ctgan 7000 1 daegu 0 121 (128, 128) \n",
|
| 450 |
+
"3 ctgan 7000 1 daegu 1 217 (128, 128) \n",
|
| 451 |
+
"4 ctgan 7000 1 daejeon 0 101 (128, 128) \n",
|
| 452 |
+
"\n",
|
| 453 |
+
" discriminator_dim pac batch_size discriminator_steps epochs \\\n",
|
| 454 |
+
"0 (128, 128) 8 256 3 300 \n",
|
| 455 |
+
"1 (128, 128) 4 1024 1 300 \n",
|
| 456 |
+
"2 (64, 64) 4 64 2 300 \n",
|
| 457 |
+
"3 (128, 128) 4 256 5 300 \n",
|
| 458 |
+
"4 (128, 128) 4 128 2 300 \n",
|
| 459 |
+
"\n",
|
| 460 |
+
" generator_lr discriminator_lr filename \n",
|
| 461 |
+
"0 0.0002 0.0002 ctgan_only_7000_1_busan_class0.pkl \n",
|
| 462 |
+
"1 0.0002 0.0002 ctgan_only_7000_1_busan_class1.pkl \n",
|
| 463 |
+
"2 0.0002 0.0002 ctgan_only_7000_1_daegu_class0.pkl \n",
|
| 464 |
+
"3 0.0002 0.0002 ctgan_only_7000_1_daegu_class1.pkl \n",
|
| 465 |
+
"4 0.0002 0.0002 ctgan_only_7000_1_daejeon_class0.pkl "
|
| 466 |
+
]
|
| 467 |
+
},
|
| 468 |
+
"execution_count": 7,
|
| 469 |
+
"metadata": {},
|
| 470 |
+
"output_type": "execute_result"
|
| 471 |
+
}
|
| 472 |
+
],
|
| 473 |
+
"source": [
|
| 474 |
+
"# DataFrame์ผ๋ก ๋ณํ\n",
|
| 475 |
+
"df_hyperparams = pd.DataFrame(all_results)\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"# ์ปฌ๋ผ ์์ ์ ๋ฆฌ\n",
|
| 478 |
+
"column_order = [\n",
|
| 479 |
+
" 'method', 'sample_size', 'fold', 'region', 'class',\n",
|
| 480 |
+
" 'embedding_dim', 'generator_dim', 'discriminator_dim',\n",
|
| 481 |
+
" 'pac', 'batch_size', 'discriminator_steps',\n",
|
| 482 |
+
" 'epochs', 'generator_lr', 'discriminator_lr',\n",
|
| 483 |
+
" 'filename'\n",
|
| 484 |
+
"]\n",
|
| 485 |
+
"df_hyperparams = df_hyperparams[column_order]\n",
|
| 486 |
+
"\n",
|
| 487 |
+
"# ์ ๋ ฌ: method -> sample_size -> fold -> region -> class\n",
|
| 488 |
+
"df_hyperparams = df_hyperparams.sort_values(\n",
|
| 489 |
+
" ['method', 'sample_size', 'fold', 'region', 'class']\n",
|
| 490 |
+
").reset_index(drop=True)\n",
|
| 491 |
+
"\n",
|
| 492 |
+
"print(f\"์ด {len(df_hyperparams)}๊ฐ์ ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ๊ฐ ์ ๋ฆฌ๋์์ต๋๋ค.\")\n",
|
| 493 |
+
"print(f\"\\n์ปฌ๋ผ: {list(df_hyperparams.columns)}\")\n",
|
| 494 |
+
"print(f\"\\n์ฒ์ 5๊ฐ ํ:\")\n",
|
| 495 |
+
"df_hyperparams.head()\n"
|
| 496 |
+
]
|
| 497 |
+
},
|
| 498 |
+
{
|
| 499 |
+
"cell_type": "code",
|
| 500 |
+
"execution_count": 17,
|
| 501 |
+
"id": "9d3a8a65",
|
| 502 |
+
"metadata": {},
|
| 503 |
+
"outputs": [],
|
| 504 |
+
"source": [
|
| 505 |
+
"df_hyperparams.sort_values(by=['region','method','sample_size','fold','class'], inplace=True)"
|
| 506 |
+
]
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"cell_type": "code",
|
| 510 |
+
"execution_count": 24,
|
| 511 |
+
"id": "f92f352e",
|
| 512 |
+
"metadata": {},
|
| 513 |
+
"outputs": [
|
| 514 |
+
{
|
| 515 |
+
"name": "stdout",
|
| 516 |
+
"output_type": "stream",
|
| 517 |
+
"text": [
|
| 518 |
+
"ํ์ดํผํ๋ผ๋ฏธํฐ ๋ฐ์ดํฐ๊ฐ 'oversampling_models_hyperparameters_all.csv'์ ์ ์ฅ๋์์ต๋๋ค.\n"
|
| 519 |
+
]
|
| 520 |
+
}
|
| 521 |
+
],
|
| 522 |
+
"source": [
|
| 523 |
+
"# CSV๋ก ์ ์ฅ (์ ํ์ฌํญ)\n",
|
| 524 |
+
"output_csv = \"oversampling_models_hyperparameters_all.csv\"\n",
|
| 525 |
+
"df_hyperparams.to_csv(output_csv, index=False, encoding='utf-8-sig')\n",
|
| 526 |
+
"print(f\"ํ์ดํผํ๋ผ๋ฏธํฐ ๋ฐ์ดํฐ๊ฐ '{output_csv}'์ ์ ์ฅ๋์์ต๋๋ค.\")"
|
| 527 |
+
]
|
| 528 |
+
},
|
| 529 |
+
{
|
| 530 |
+
"cell_type": "code",
|
| 531 |
+
"execution_count": 25,
|
| 532 |
+
"id": "8ee1c56a",
|
| 533 |
+
"metadata": {},
|
| 534 |
+
"outputs": [
|
| 535 |
+
{
|
| 536 |
+
"data": {
|
| 537 |
+
"text/plain": [
|
| 538 |
+
"ctgan 108\n",
|
| 539 |
+
"smotenc_ctgan 108\n",
|
| 540 |
+
"Name: method, dtype: int64"
|
| 541 |
+
]
|
| 542 |
+
},
|
| 543 |
+
"execution_count": 25,
|
| 544 |
+
"metadata": {},
|
| 545 |
+
"output_type": "execute_result"
|
| 546 |
+
}
|
| 547 |
+
],
|
| 548 |
+
"source": [
|
| 549 |
+
"df_hyperparams['method'].value_counts()"
|
| 550 |
+
]
|
| 551 |
+
}
|
| 552 |
+
],
|
| 553 |
+
"metadata": {
|
| 554 |
+
"kernelspec": {
|
| 555 |
+
"display_name": "py39",
|
| 556 |
+
"language": "python",
|
| 557 |
+
"name": "python3"
|
| 558 |
+
},
|
| 559 |
+
"language_info": {
|
| 560 |
+
"codemirror_mode": {
|
| 561 |
+
"name": "ipython",
|
| 562 |
+
"version": 3
|
| 563 |
+
},
|
| 564 |
+
"file_extension": ".py",
|
| 565 |
+
"mimetype": "text/x-python",
|
| 566 |
+
"name": "python",
|
| 567 |
+
"nbconvert_exporter": "python",
|
| 568 |
+
"pygments_lexer": "ipython3",
|
| 569 |
+
"version": "3.9.18"
|
| 570 |
+
}
|
| 571 |
+
},
|
| 572 |
+
"nbformat": 4,
|
| 573 |
+
"nbformat_minor": 5
|
| 574 |
+
}
|
Analysis_code/4.sampling_data_test/analysis.ipynb
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "70effd7a",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pandas as pd\n",
|
| 11 |
+
"import numpy as np"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 3,
|
| 17 |
+
"id": "f38ce7d1",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
"df= pd.read_csv(\"../../data/oversampled_data_test_for_model/combined_sampled_data_test.csv\")"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": 4,
|
| 27 |
+
"id": "2bae91e4",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"outputs": [
|
| 30 |
+
{
|
| 31 |
+
"data": {
|
| 32 |
+
"text/html": [
|
| 33 |
+
"<div>\n",
|
| 34 |
+
"<style scoped>\n",
|
| 35 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 36 |
+
" vertical-align: middle;\n",
|
| 37 |
+
" }\n",
|
| 38 |
+
"\n",
|
| 39 |
+
" .dataframe tbody tr th {\n",
|
| 40 |
+
" vertical-align: top;\n",
|
| 41 |
+
" }\n",
|
| 42 |
+
"\n",
|
| 43 |
+
" .dataframe thead th {\n",
|
| 44 |
+
" text-align: right;\n",
|
| 45 |
+
" }\n",
|
| 46 |
+
"</style>\n",
|
| 47 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 48 |
+
" <thead>\n",
|
| 49 |
+
" <tr style=\"text-align: right;\">\n",
|
| 50 |
+
" <th></th>\n",
|
| 51 |
+
" <th>region</th>\n",
|
| 52 |
+
" <th>model</th>\n",
|
| 53 |
+
" <th>data_sample</th>\n",
|
| 54 |
+
" <th>CSI</th>\n",
|
| 55 |
+
" <th>MCC</th>\n",
|
| 56 |
+
" <th>Accuracy</th>\n",
|
| 57 |
+
" <th>fold_csi</th>\n",
|
| 58 |
+
" </tr>\n",
|
| 59 |
+
" </thead>\n",
|
| 60 |
+
" <tbody>\n",
|
| 61 |
+
" <tr>\n",
|
| 62 |
+
" <th>0</th>\n",
|
| 63 |
+
" <td>seoul</td>\n",
|
| 64 |
+
" <td>LightGBM</td>\n",
|
| 65 |
+
" <td>pure</td>\n",
|
| 66 |
+
" <td>0.505041</td>\n",
|
| 67 |
+
" <td>0.646992</td>\n",
|
| 68 |
+
" <td>0.936174</td>\n",
|
| 69 |
+
" <td>[[0.46595932802825235, 0.5771195097037204, 0.4...</td>\n",
|
| 70 |
+
" </tr>\n",
|
| 71 |
+
" <tr>\n",
|
| 72 |
+
" <th>1</th>\n",
|
| 73 |
+
" <td>busan</td>\n",
|
| 74 |
+
" <td>LightGBM</td>\n",
|
| 75 |
+
" <td>pure</td>\n",
|
| 76 |
+
" <td>0.430188</td>\n",
|
| 77 |
+
" <td>0.600801</td>\n",
|
| 78 |
+
" <td>0.956971</td>\n",
|
| 79 |
+
" <td>[[0.32824427480911017, 0.4782608695651431, 0.4...</td>\n",
|
| 80 |
+
" </tr>\n",
|
| 81 |
+
" </tbody>\n",
|
| 82 |
+
"</table>\n",
|
| 83 |
+
"</div>"
|
| 84 |
+
],
|
| 85 |
+
"text/plain": [
|
| 86 |
+
" region model data_sample CSI MCC Accuracy \\\n",
|
| 87 |
+
"0 seoul LightGBM pure 0.505041 0.646992 0.936174 \n",
|
| 88 |
+
"1 busan LightGBM pure 0.430188 0.600801 0.956971 \n",
|
| 89 |
+
"\n",
|
| 90 |
+
" fold_csi \n",
|
| 91 |
+
"0 [[0.46595932802825235, 0.5771195097037204, 0.4... \n",
|
| 92 |
+
"1 [[0.32824427480911017, 0.4782608695651431, 0.4... "
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
"execution_count": 4,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"output_type": "execute_result"
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"source": [
|
| 101 |
+
"df.head(2)"
|
| 102 |
+
]
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"cell_type": "code",
|
| 106 |
+
"execution_count": 5,
|
| 107 |
+
"id": "6893a958",
|
| 108 |
+
"metadata": {},
|
| 109 |
+
"outputs": [
|
| 110 |
+
{
|
| 111 |
+
"data": {
|
| 112 |
+
"text/html": [
|
| 113 |
+
"<div>\n",
|
| 114 |
+
"<style scoped>\n",
|
| 115 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 116 |
+
" vertical-align: middle;\n",
|
| 117 |
+
" }\n",
|
| 118 |
+
"\n",
|
| 119 |
+
" .dataframe tbody tr th {\n",
|
| 120 |
+
" vertical-align: top;\n",
|
| 121 |
+
" }\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" .dataframe thead th {\n",
|
| 124 |
+
" text-align: right;\n",
|
| 125 |
+
" }\n",
|
| 126 |
+
"</style>\n",
|
| 127 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 128 |
+
" <thead>\n",
|
| 129 |
+
" <tr style=\"text-align: right;\">\n",
|
| 130 |
+
" <th></th>\n",
|
| 131 |
+
" <th>region</th>\n",
|
| 132 |
+
" <th>model</th>\n",
|
| 133 |
+
" <th>data_sample</th>\n",
|
| 134 |
+
" <th>CSI</th>\n",
|
| 135 |
+
" </tr>\n",
|
| 136 |
+
" </thead>\n",
|
| 137 |
+
" <tbody>\n",
|
| 138 |
+
" <tr>\n",
|
| 139 |
+
" <th>0</th>\n",
|
| 140 |
+
" <td>busan</td>\n",
|
| 141 |
+
" <td>LightGBM</td>\n",
|
| 142 |
+
" <td>ctgan10000</td>\n",
|
| 143 |
+
" <td>0.467663</td>\n",
|
| 144 |
+
" </tr>\n",
|
| 145 |
+
" <tr>\n",
|
| 146 |
+
" <th>1</th>\n",
|
| 147 |
+
" <td>daegu</td>\n",
|
| 148 |
+
" <td>XGBoost</td>\n",
|
| 149 |
+
" <td>smote</td>\n",
|
| 150 |
+
" <td>0.454066</td>\n",
|
| 151 |
+
" </tr>\n",
|
| 152 |
+
" <tr>\n",
|
| 153 |
+
" <th>2</th>\n",
|
| 154 |
+
" <td>daejeon</td>\n",
|
| 155 |
+
" <td>LightGBM</td>\n",
|
| 156 |
+
" <td>smote</td>\n",
|
| 157 |
+
" <td>0.521335</td>\n",
|
| 158 |
+
" </tr>\n",
|
| 159 |
+
" <tr>\n",
|
| 160 |
+
" <th>3</th>\n",
|
| 161 |
+
" <td>gwangju</td>\n",
|
| 162 |
+
" <td>LightGBM</td>\n",
|
| 163 |
+
" <td>smote</td>\n",
|
| 164 |
+
" <td>0.522731</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <th>4</th>\n",
|
| 168 |
+
" <td>incheon</td>\n",
|
| 169 |
+
" <td>XGBoost</td>\n",
|
| 170 |
+
" <td>smote</td>\n",
|
| 171 |
+
" <td>0.589146</td>\n",
|
| 172 |
+
" </tr>\n",
|
| 173 |
+
" <tr>\n",
|
| 174 |
+
" <th>5</th>\n",
|
| 175 |
+
" <td>seoul</td>\n",
|
| 176 |
+
" <td>XGBoost</td>\n",
|
| 177 |
+
" <td>smote</td>\n",
|
| 178 |
+
" <td>0.582266</td>\n",
|
| 179 |
+
" </tr>\n",
|
| 180 |
+
" </tbody>\n",
|
| 181 |
+
"</table>\n",
|
| 182 |
+
"</div>"
|
| 183 |
+
],
|
| 184 |
+
"text/plain": [
|
| 185 |
+
" region model data_sample CSI\n",
|
| 186 |
+
"0 busan LightGBM ctgan10000 0.467663\n",
|
| 187 |
+
"1 daegu XGBoost smote 0.454066\n",
|
| 188 |
+
"2 daejeon LightGBM smote 0.521335\n",
|
| 189 |
+
"3 gwangju LightGBM smote 0.522731\n",
|
| 190 |
+
"4 incheon XGBoost smote 0.589146\n",
|
| 191 |
+
"5 seoul XGBoost smote 0.582266"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
"execution_count": 5,
|
| 195 |
+
"metadata": {},
|
| 196 |
+
"output_type": "execute_result"
|
| 197 |
+
}
|
| 198 |
+
],
|
| 199 |
+
"source": [
|
| 200 |
+
"# ์ง์ญ๋ณ๋ก CSI๊ฐ ๊ฐ์ฅ ๋์ model๊ณผ data_sample ์กฐํฉ ๋ณด๊ธฐ\n",
|
| 201 |
+
"top_csi_per_region = df.loc[df.groupby('region')['CSI'].idxmax()][['region', 'model', 'data_sample', 'CSI']]\n",
|
| 202 |
+
"top_csi_per_region = top_csi_per_region.sort_values('region').reset_index(drop=True)\n",
|
| 203 |
+
"top_csi_per_region"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": null,
|
| 209 |
+
"id": "2942ba86",
|
| 210 |
+
"metadata": {},
|
| 211 |
+
"outputs": [],
|
| 212 |
+
"source": []
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": null,
|
| 217 |
+
"id": "d55af59c",
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": []
|
| 221 |
+
}
|
| 222 |
+
],
|
| 223 |
+
"metadata": {
|
| 224 |
+
"kernelspec": {
|
| 225 |
+
"display_name": "py39",
|
| 226 |
+
"language": "python",
|
| 227 |
+
"name": "python3"
|
| 228 |
+
},
|
| 229 |
+
"language_info": {
|
| 230 |
+
"codemirror_mode": {
|
| 231 |
+
"name": "ipython",
|
| 232 |
+
"version": 3
|
| 233 |
+
},
|
| 234 |
+
"file_extension": ".py",
|
| 235 |
+
"mimetype": "text/x-python",
|
| 236 |
+
"name": "python",
|
| 237 |
+
"nbconvert_exporter": "python",
|
| 238 |
+
"pygments_lexer": "ipython3",
|
| 239 |
+
"version": "3.9.18"
|
| 240 |
+
}
|
| 241 |
+
},
|
| 242 |
+
"nbformat": 4,
|
| 243 |
+
"nbformat_minor": 5
|
| 244 |
+
}
|
Analysis_code/4.sampling_data_test/lgb_sampled_test.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Analysis_code/4.sampling_data_test/xgb_sampled_test.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_busan.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Fix the Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)


def print_trial_callback(study, trial):
    """Print detailed information after each trial, including the best value so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f" Value (CSI): {trial.value:.6f}")
    else:
        print(f" Value: {trial.value}")
    print(f" Parameters: {trial.params}")
    # BUG FIX: `study.best_value` / `study.best_trial` raise ValueError (they
    # do not return None) while no trial has completed yet — e.g. when the
    # very first trial fails or is pruned — so the original
    # `if study.best_value is not None` check could never protect this code;
    # the attribute access itself raises. Guard with try/except instead.
    try:
        print(f" Best Value (CSI): {study.best_value:.6f}")
        print(f" Best Trial: {study.best_trial.number}")
        print(f" Best Parameters: {study.best_params}")
    except ValueError:
        print(" Best Value (CSI): n/a (no completed trial yet)")
    print(f"{'='*80}\n")


# 1. Create the study with direction='maximize' (a higher CSI score is better).
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # observe first 10 steps, then allow pruning
)

# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="busan"),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. Inspect and summarize the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every successful trial.
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f" - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f" - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f" - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f" - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f" - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f" - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f" - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the study object. The base directory is derived from this
        # file's location (two levels up -> the 5.optima directory).
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_pure_busan_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model with the optimized hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_paths = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="busan",
            data_sample='pure',
            target='multi',
            n_folds=3,
            random_state=seed
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก:")
        for path in model_paths:
            print(f" - {path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# Exit cleanly.
import sys
sys.exit(0)
|
Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_daegu.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Fix the Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)


def print_trial_callback(study, trial):
    """Print detailed information after each trial, including the best value so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f" Value (CSI): {trial.value:.6f}")
    else:
        print(f" Value: {trial.value}")
    print(f" Parameters: {trial.params}")
    # BUG FIX: `study.best_value` / `study.best_trial` raise ValueError (they
    # do not return None) while no trial has completed yet — e.g. when the
    # very first trial fails or is pruned — so the original
    # `if study.best_value is not None` check could never protect this code;
    # the attribute access itself raises. Guard with try/except instead.
    try:
        print(f" Best Value (CSI): {study.best_value:.6f}")
        print(f" Best Trial: {study.best_trial.number}")
        print(f" Best Parameters: {study.best_params}")
    except ValueError:
        print(" Best Value (CSI): n/a (no completed trial yet)")
    print(f"{'='*80}\n")


# 1. Create the study with direction='maximize' (a higher CSI score is better).
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # observe first 10 steps, then allow pruning
)

# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="daegu"),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. Inspect and summarize the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every successful trial.
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f" - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f" - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f" - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f" - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f" - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f" - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f" - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the study object. The base directory is derived from this
        # file's location (two levels up -> the 5.optima directory).
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_pure_daegu_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model with the optimized hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_paths = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="daegu",
            data_sample='pure',
            target='multi',
            n_folds=3,
            random_state=seed
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก:")
        for path in model_paths:
            print(f" - {path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# Exit cleanly.
import sys
sys.exit(0)
|
Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_daejeon.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Fix the Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)


def print_trial_callback(study, trial):
    """Print detailed information after each trial, including the best value so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f" Value (CSI): {trial.value:.6f}")
    else:
        print(f" Value: {trial.value}")
    print(f" Parameters: {trial.params}")
    # BUG FIX: `study.best_value` / `study.best_trial` raise ValueError (they
    # do not return None) while no trial has completed yet — e.g. when the
    # very first trial fails or is pruned — so the original
    # `if study.best_value is not None` check could never protect this code;
    # the attribute access itself raises. Guard with try/except instead.
    try:
        print(f" Best Value (CSI): {study.best_value:.6f}")
        print(f" Best Trial: {study.best_trial.number}")
        print(f" Best Parameters: {study.best_params}")
    except ValueError:
        print(" Best Value (CSI): n/a (no completed trial yet)")
    print(f"{'='*80}\n")


# 1. Create the study with direction='maximize' (a higher CSI score is better).
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # observe first 10 steps, then allow pruning
)

# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="daejeon"),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. Inspect and summarize the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every successful trial.
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f" - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f" - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f" - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f" - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f" - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f" - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f" - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the study object. The base directory is derived from this
        # file's location (two levels up -> the 5.optima directory).
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_pure_daejeon_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model with the optimized hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_paths = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="daejeon",
            data_sample='pure',
            target='multi',
            n_folds=3,
            random_state=seed
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก:")
        for path in model_paths:
            print(f" - {path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# Exit cleanly.
import sys
sys.exit(0)
|
Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_gwangju.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Fix the Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)


def print_trial_callback(study, trial):
    """Print detailed information after each trial, including the best value so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f" Value (CSI): {trial.value:.6f}")
    else:
        print(f" Value: {trial.value}")
    print(f" Parameters: {trial.params}")
    # BUG FIX: `study.best_value` / `study.best_trial` raise ValueError (they
    # do not return None) while no trial has completed yet — e.g. when the
    # very first trial fails or is pruned — so the original
    # `if study.best_value is not None` check could never protect this code;
    # the attribute access itself raises. Guard with try/except instead.
    try:
        print(f" Best Value (CSI): {study.best_value:.6f}")
        print(f" Best Trial: {study.best_trial.number}")
        print(f" Best Parameters: {study.best_params}")
    except ValueError:
        print(" Best Value (CSI): n/a (no completed trial yet)")
    print(f"{'='*80}\n")


# 1. Create the study with direction='maximize' (a higher CSI score is better).
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # observe first 10 steps, then allow pruning
)

# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="gwangju"),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. Inspect and summarize the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every successful trial.
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f" - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f" - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f" - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f" - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f" - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f" - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f" - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the study object. The base directory is derived from this
        # file's location (two levels up -> the 5.optima directory).
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_pure_gwangju_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model with the optimized hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_paths = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="gwangju",
            data_sample='pure',
            target='multi',
            n_folds=3,
            random_state=seed
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก:")
        for path in model_paths:
            print(f" - {path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# Exit cleanly.
import sys
sys.exit(0)
|
Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_incheon.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Fix the Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)


def print_trial_callback(study, trial):
    """Print detailed information after each trial, including the best value so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f" Value (CSI): {trial.value:.6f}")
    else:
        print(f" Value: {trial.value}")
    print(f" Parameters: {trial.params}")
    # BUG FIX: `study.best_value` / `study.best_trial` raise ValueError (they
    # do not return None) while no trial has completed yet — e.g. when the
    # very first trial fails or is pruned — so the original
    # `if study.best_value is not None` check could never protect this code;
    # the attribute access itself raises. Guard with try/except instead.
    try:
        print(f" Best Value (CSI): {study.best_value:.6f}")
        print(f" Best Trial: {study.best_trial.number}")
        print(f" Best Parameters: {study.best_params}")
    except ValueError:
        print(" Best Value (CSI): n/a (no completed trial yet)")
    print(f"{'='*80}\n")


# 1. Create the study with direction='maximize' (a higher CSI score is better).
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # observe first 10 steps, then allow pruning
)

# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="incheon"),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. Inspect and summarize the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every successful trial.
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f" - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f" - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f" - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f" - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f" - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f" - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f" - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the study object. The base directory is derived from this
        # file's location (two levels up -> the 5.optima directory).
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_pure_incheon_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model with the optimized hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_paths = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="incheon",
            data_sample='pure',
            target='multi',
            n_folds=3,
            random_state=seed
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก:")
        for path in model_paths:
            print(f" - {path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# Exit cleanly.
import sys
sys.exit(0)
|
Analysis_code/5.optima/deepgbm_pure/deepgbm_pure_seoul.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Fix the Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)


def print_trial_callback(study, trial):
    """Print detailed information after each trial, including the best value so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f" Value (CSI): {trial.value:.6f}")
    else:
        print(f" Value: {trial.value}")
    print(f" Parameters: {trial.params}")
    # BUG FIX: `study.best_value` / `study.best_trial` raise ValueError (they
    # do not return None) while no trial has completed yet — e.g. when the
    # very first trial fails or is pruned — so the original
    # `if study.best_value is not None` check could never protect this code;
    # the attribute access itself raises. Guard with try/except instead.
    try:
        print(f" Best Value (CSI): {study.best_value:.6f}")
        print(f" Best Trial: {study.best_trial.number}")
        print(f" Best Parameters: {study.best_params}")
    except ValueError:
        print(" Best Value (CSI): n/a (no completed trial yet)")
    print(f"{'='*80}\n")


# 1. Create the study with direction='maximize' (a higher CSI score is better).
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # observe first 10 steps, then allow pruning
)

# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="seoul"),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. Inspect and summarize the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every successful trial.
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f" - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f" - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f" - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f" - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f" - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f" - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f" - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the study object. The base directory is derived from this
        # file's location (two levels up -> the 5.optima directory).
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_pure_seoul_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model with the optimized hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_paths = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="seoul",
            data_sample='pure',
            target='multi',
            n_folds=3,
            random_state=seed
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก:")
        for path in model_paths:
            print(f" - {path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# Exit cleanly.
import sys
sys.exit(0)
|
Analysis_code/5.optima/deepgbm_pure/utils.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import copy
|
| 8 |
+
from sklearn.preprocessing import QuantileTransformer, LabelEncoder
|
| 9 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 10 |
+
from sklearn.metrics import confusion_matrix
|
| 11 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import optuna
|
| 14 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 15 |
+
import joblib
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
# ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก models ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
|
| 20 |
+
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
| 21 |
+
models_path = os.path.abspath(os.path.join(current_file_dir, '../../models'))
|
| 22 |
+
sys.path.insert(0, models_path)
|
| 23 |
+
from ft_transformer import FTTransformer
|
| 24 |
+
from resnet_like import ResNetLike
|
| 25 |
+
from deepgbm import DeepGBM
|
| 26 |
+
import warnings
|
| 27 |
+
warnings.filterwarnings('ignore')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Fix Python and NumPy random seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)

# Fix PyTorch seeds (CPU and every GPU in a multi-GPU setup).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Deterministic cuDNN. NOTE: benchmark must be False for reproducibility;
# with benchmark=True cuDNN auto-tunes and may pick different algorithms
# between runs, defeating the deterministic flag set just above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived feature columns removed upstream.

    Adds sine/cosine encodings of the cyclic `hour` (period 24) and
    `month` (period 12) columns, plus the ground-air temperature gap
    `'ground_temp - temp_C'` (= groundtemp - temp_C).

    Args:
        df: dataframe containing at least `hour`, `month`, `groundtemp`
            and `temp_C` columns.

    Returns:
        A copy of *df* with the five derived columns appended; the input
        dataframe is left untouched.
    """
    out = df.copy()
    # Cyclic encodings: map each periodic column onto the unit circle so
    # that e.g. hour 23 and hour 0 end up adjacent.
    for col, period in (('hour', 24), ('month', 12)):
        angle = 2 * np.pi * out[col] / period
        out[f'{col}_sin'] = np.sin(angle)
        out[f'{col}_cos'] = np.cos(angle)
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 62 |
+
|
| 63 |
+
def preprocessing(df):
    """Preprocess a raw weather/air-quality dataframe for modeling.

    Casts year/month/hour to int, appends the derived cyclic features via
    add_derived_features, normalizes the calm-wind marker in `wind_dir`
    to "0" so the column can be cast to int, and finally restricts the
    frame to the fixed feature set plus the `multi_class` target.

    Args:
        df: raw dataframe as read from the modeling CSVs.

    Returns:
        Preprocessed dataframe with exactly the model's feature columns
        and `multi_class`.
    """
    df = df[df.columns].copy()  # defensive copy (column selection is a no-op)
    df['year'] = df['year'].astype('int')
    df['month'] = df['month'].astype('int')
    df['hour'] = df['hour'].astype('int')
    df = add_derived_features(df).copy()
    df['multi_class'] = df['multi_class'].astype('int')
    # The calm-wind text marker (a Korean label, garbled here by a past
    # encoding round-trip) is mapped to "0" before the int cast.
    df.loc[df['wind_dir']=='์ ์จ', 'wind_dir'] = "0"
    df['wind_dir'] = df['wind_dir'].astype('int')
    # Fixed column order keeps downstream feature/target splits consistent.
    df = df[['temp_C', 'precip_mm', 'wind_speed', 'wind_dir', 'hm',
             'vap_pressure', 'dewpoint_C', 'loc_pressure', 'sea_pressure',
             'solarRad', 'snow_cm', 'cloudcover', 'lm_cloudcover', 'low_cloudbase',
             'groundtemp', 'O3', 'NO2', 'PM10', 'PM25', 'year',
             'month', 'hour', 'ground_temp - temp_C', 'hour_sin', 'hour_cos',
             'month_sin', 'month_cos','multi_class']].copy()
    return df
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ๋ฐ์ดํฐ์
์ค๋น ํจ์
|
| 90 |
+
def prepare_dataset(region, data_sample='pure', target='multi', fold=3):
    """Load, split and encode one region's dataset as plain dataframes.

    Year-based validation: fold f holds out year ``2021 - f`` from the
    training CSV as the validation set. With ``data_sample='pure'`` the
    remaining years are the training set; otherwise an oversampled CSV is
    loaded as the training set instead.

    Args:
        region: region name used to build the CSV file names.
        data_sample: 'pure' or the name of an oversampled dataset directory.
        target: target prefix; the label column is f'{target}_class'.
            NOTE(review): preprocessing keeps only 'multi_class', so
            target='binary' would raise KeyError here — confirm intended use.
        fold: validation fold index (held-out year = 2021 - fold).

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test,
         categorical_cols, numerical_cols)
    """
    # Resolve the data directory relative to this file's location.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    data_base_dir = os.path.abspath(os.path.join(current_file_dir, '../../../data'))

    # Build CSV paths.
    dat_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_train.csv")
    if data_sample == 'pure':
        train_path = dat_path
    else:
        train_path = os.path.join(data_base_dir, f'data_oversampled/{data_sample}/{data_sample}_{fold}_{region}.csv')
    test_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_test.csv")
    drop_col = ['multi_class','year']
    target_col = f'{target}_class'

    # Load and split by held-out year (2021 - fold).
    region_dat = preprocessing(pd.read_csv(dat_path, index_col=0))
    if data_sample == 'pure':
        region_train = region_dat.loc[~region_dat['year'].isin([2021-fold]), :]
    else:
        region_train = preprocessing(pd.read_csv(train_path))
    region_val = region_dat.loc[region_dat['year'].isin([2021-fold]), :]
    region_test = preprocessing(pd.read_csv(test_path))

    # Align column order across the three splits.
    common_columns = region_train.columns.to_list()
    train_data = region_train[common_columns]
    val_data = region_val[common_columns]
    test_data = region_test[common_columns]

    # Feature / target separation.
    X_train = train_data.drop(columns=drop_col)
    y_train = train_data[target_col]
    X_val = val_data.drop(columns=drop_col)
    y_val = val_data[target_col]
    X_test = test_data.drop(columns=drop_col)
    y_test = test_data[target_col]

    # Categorical = object/category/int64 dtypes; numerical = float64.
    categorical_cols = X_train.select_dtypes(include=['object', 'category', 'int64']).columns
    numerical_cols = X_train.select_dtypes(include=['float64']).columns

    # Label-encode categoricals, fitting on train only.
    # NOTE(review): transform raises on categories unseen in train — assumes
    # val/test category sets are subsets of train's; confirm with the data.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        le.fit(X_train[col])  # fit on training data only
        label_encoders[col] = le

    # Apply the fitted encoders to all splits.
    for col in categorical_cols:
        X_train[col] = label_encoders[col].transform(X_train[col])
        X_val[col] = label_encoders[col].transform(X_val[col])
        X_test[col] = label_encoders[col].transform(X_test[col])

    # Quantile-normalize continuous features, fitting on train only.
    scaler = QuantileTransformer(output_distribution='normal')
    scaler.fit(X_train[numerical_cols])

    # Apply the fitted scaler to all splits.
    X_train[numerical_cols] = scaler.transform(X_train[numerical_cols])
    X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
    X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    return X_train, X_val, X_test, y_train, y_val, y_test, categorical_cols, numerical_cols
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ๋ฐ์ดํฐ ๋ณํ ๋ฐ dataloader ์์ฑ ํจ์
|
| 160 |
+
def prepare_dataloader(region, data_sample='pure', target='multi', fold=3, random_state=None):
    """Load, split, encode one region's dataset and wrap it in DataLoaders.

    Same pipeline as prepare_dataset (year-based fold split, label encoding
    and quantile scaling fitted on train only), then converts each split
    into (numerical, categorical, label) tensors and builds DataLoaders
    with a fixed batch size of 64.

    Args:
        region: region name used to build the CSV file names.
        data_sample: 'pure' or the name of an oversampled dataset directory.
        target: 'multi' (long labels for CrossEntropy) or 'binary'
            (float32 labels for BCEWithLogits).
        fold: validation fold index (held-out year = 2021 - fold).
        random_state: optional seed for the training loader's shuffle
            generator; None uses PyTorch's default generator.

    Returns:
        (X_train, categorical_cols, numerical_cols,
         train_loader, val_loader, test_loader)
    """
    # Resolve the data directory relative to this file's location.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    data_base_dir = os.path.abspath(os.path.join(current_file_dir, '../../../data'))

    # Build CSV paths.
    dat_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_train.csv")
    if data_sample == 'pure':
        train_path = dat_path
    else:
        train_path = os.path.join(data_base_dir, f'data_oversampled/{data_sample}/{data_sample}_{fold}_{region}.csv')
    test_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_test.csv")
    drop_col = ['multi_class','year']
    target_col = f'{target}_class'

    # Load and split by held-out year (2021 - fold).
    region_dat = preprocessing(pd.read_csv(dat_path, index_col=0))
    if data_sample == 'pure':
        region_train = region_dat.loc[~region_dat['year'].isin([2021-fold]), :]
    else:
        region_train = preprocessing(pd.read_csv(train_path))
    region_val = region_dat.loc[region_dat['year'].isin([2021-fold]), :]
    region_test = preprocessing(pd.read_csv(test_path))

    # Align column order across the three splits.
    common_columns = region_train.columns.to_list()
    train_data = region_train[common_columns]
    val_data = region_val[common_columns]
    test_data = region_test[common_columns]

    # Feature / target separation.
    X_train = train_data.drop(columns=drop_col)
    y_train = train_data[target_col]
    X_val = val_data.drop(columns=drop_col)
    y_val = val_data[target_col]
    X_test = test_data.drop(columns=drop_col)
    y_test = test_data[target_col]

    # Categorical = object/category/int64 dtypes; numerical = float64.
    categorical_cols = X_train.select_dtypes(include=['object', 'category', 'int64']).columns
    numerical_cols = X_train.select_dtypes(include=['float64']).columns

    # Label-encode categoricals, fitting on train only.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        le.fit(X_train[col])  # fit on training data only
        label_encoders[col] = le

    # Apply the fitted encoders to all splits.
    for col in categorical_cols:
        X_train[col] = label_encoders[col].transform(X_train[col])
        X_val[col] = label_encoders[col].transform(X_val[col])
        X_test[col] = label_encoders[col].transform(X_test[col])

    # Quantile-normalize continuous features, fitting on train only.
    scaler = QuantileTransformer(output_distribution='normal')
    scaler.fit(X_train[numerical_cols])

    # Apply the fitted scaler to all splits.
    X_train[numerical_cols] = scaler.transform(X_train[numerical_cols])
    X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
    X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    # Tensor conversion: numerical features float32, categoricals long
    # (long is required for embedding lookups).
    X_train_num = torch.tensor(X_train[numerical_cols].values, dtype=torch.float32)
    X_train_cat = torch.tensor(X_train[categorical_cols].values, dtype=torch.long)

    X_val_num = torch.tensor(X_val[numerical_cols].values, dtype=torch.float32)
    X_val_cat = torch.tensor(X_val[categorical_cols].values, dtype=torch.long)

    X_test_num = torch.tensor(X_test[numerical_cols].values, dtype=torch.float32)
    X_test_cat = torch.tensor(X_test[categorical_cols].values, dtype=torch.long)

    # Label dtype depends on the loss the caller will use.
    if target == "binary":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)  # binary -> float32 for BCE
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.float32)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)
    elif target == "multi":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)  # multi-class -> long for CE
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.long)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)
    else:
        raise ValueError("target must be 'binary' or 'multi'")

    # Each sample yields (numerical, categorical, label).
    train_dataset = TensorDataset(X_train_num, X_train_cat, y_train_tensor)
    val_dataset = TensorDataset(X_val_num, X_val_cat, y_val_tensor)
    test_dataset = TensorDataset(X_test_num, X_test_cat, y_test_tensor)

    # Only the training loader shuffles; its generator is seeded when a
    # random_state is given.  (NOTE(review): `== None` should be `is None`.)
    if random_state == None:
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    else:
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=torch.Generator().manual_seed(random_state))
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    return X_train, categorical_cols, numerical_cols, train_loader, val_loader, test_loader
|
| 261 |
+
|
| 262 |
+
# ๋ฐ์ดํฐ ๋ณํ ๋ฐ dataloader ์์ฑ ํจ์ (batch_size ํ๋ผ๋ฏธํฐ ์ถ๊ฐ ๋ฒ์ )
|
| 263 |
+
def prepare_dataloader_with_batchsize(region, data_sample='pure', target='multi', fold=3, random_state=None, batch_size=64):
    """Load, split, encode one region's dataset and wrap it in DataLoaders.

    Variant of prepare_dataloader with a configurable batch size that also
    returns the raw labels and the fitted scaler (both needed by the
    training / persistence code).

    Args:
        region: region name used to build the CSV file names.
        data_sample: 'pure' or the name of an oversampled dataset directory.
        target: 'multi' (long labels for CrossEntropy) or 'binary'
            (float32 labels for BCEWithLogits).
        fold: validation fold index (held-out year = 2021 - fold).
        random_state: optional seed for the training loader's shuffle
            generator; None uses PyTorch's default generator.
        batch_size: batch size for all three DataLoaders.

    Returns:
        (X_train, categorical_cols, numerical_cols,
         train_loader, val_loader, test_loader, y_train, scaler)
    """
    # Resolve the data directory relative to this file's location.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    data_base_dir = os.path.abspath(os.path.join(current_file_dir, '../../../data'))

    # Build CSV paths.
    dat_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_train.csv")
    if data_sample == 'pure':
        train_path = dat_path
    else:
        train_path = os.path.join(data_base_dir, f'data_oversampled/{data_sample}/{data_sample}_{fold}_{region}.csv')
    test_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_test.csv")
    drop_col = ['multi_class','year']
    target_col = f'{target}_class'

    # Load and split by held-out year (2021 - fold).
    region_dat = preprocessing(pd.read_csv(dat_path, index_col=0))
    if data_sample == 'pure':
        region_train = region_dat.loc[~region_dat['year'].isin([2021-fold]), :]
    else:
        region_train = preprocessing(pd.read_csv(train_path))
    region_val = region_dat.loc[region_dat['year'].isin([2021-fold]), :]
    region_test = preprocessing(pd.read_csv(test_path))

    # Align column order across the three splits.
    common_columns = region_train.columns.to_list()
    train_data = region_train[common_columns]
    val_data = region_val[common_columns]
    test_data = region_test[common_columns]

    # Feature / target separation.
    X_train = train_data.drop(columns=drop_col)
    y_train = train_data[target_col]
    X_val = val_data.drop(columns=drop_col)
    y_val = val_data[target_col]
    X_test = test_data.drop(columns=drop_col)
    y_test = test_data[target_col]

    # Categorical = object/category/int64 dtypes; numerical = float64.
    categorical_cols = X_train.select_dtypes(include=['object', 'category', 'int64']).columns
    numerical_cols = X_train.select_dtypes(include=['float64']).columns

    # Label-encode categoricals, fitting on train only.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        le.fit(X_train[col])  # fit on training data only
        label_encoders[col] = le

    # Apply the fitted encoders to all splits.
    for col in categorical_cols:
        X_train[col] = label_encoders[col].transform(X_train[col])
        X_val[col] = label_encoders[col].transform(X_val[col])
        X_test[col] = label_encoders[col].transform(X_test[col])

    # Quantile-normalize continuous features, fitting on train only.
    scaler = QuantileTransformer(output_distribution='normal')
    scaler.fit(X_train[numerical_cols])

    # Apply the fitted scaler to all splits.
    X_train[numerical_cols] = scaler.transform(X_train[numerical_cols])
    X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
    X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    # Tensor conversion: numerical features float32, categoricals long
    # (long is required for embedding lookups).
    X_train_num = torch.tensor(X_train[numerical_cols].values, dtype=torch.float32)
    X_train_cat = torch.tensor(X_train[categorical_cols].values, dtype=torch.long)

    X_val_num = torch.tensor(X_val[numerical_cols].values, dtype=torch.float32)
    X_val_cat = torch.tensor(X_val[categorical_cols].values, dtype=torch.long)

    X_test_num = torch.tensor(X_test[numerical_cols].values, dtype=torch.float32)
    X_test_cat = torch.tensor(X_test[categorical_cols].values, dtype=torch.long)

    # Label dtype depends on the loss the caller will use.
    if target == "binary":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)  # binary -> float32 for BCE
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.float32)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)
    elif target == "multi":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)  # multi-class -> long for CE
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.long)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)
    else:
        raise ValueError("target must be 'binary' or 'multi'")

    # Each sample yields (numerical, categorical, label).
    train_dataset = TensorDataset(X_train_num, X_train_cat, y_train_tensor)
    val_dataset = TensorDataset(X_val_num, X_val_cat, y_val_tensor)
    test_dataset = TensorDataset(X_test_num, X_test_cat, y_test_tensor)

    # Only the training loader shuffles; seed its generator when a
    # random_state is given (DataLoader treats generator=None as default).
    # Fixed: identity comparison `random_state is None` instead of `== None`.
    shuffle_generator = None if random_state is None else torch.Generator().manual_seed(random_state)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=shuffle_generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return X_train, categorical_cols, numerical_cols, train_loader, val_loader, test_loader, y_train, scaler
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def calculate_csi(y_true, pred):
    """Return the Critical Success Index (CSI) for 3-class predictions.

    With cm[i, j] = count of samples whose true class is i and predicted
    class is j, over the fixed label set {0, 1, 2}:

        H (hits)         = cm[0, 0] + cm[1, 1]
        F (false alarms) = cm[1, 0] + cm[2, 0] + cm[0, 1] + cm[2, 1]
        M (misses)       = cm[0, 2] + cm[1, 2]
        CSI = H / (H + F + M)   (small epsilon guards division by zero)

    The matrix is accumulated over labels 0..2 explicitly, so the result
    is well-defined even when a class is absent from y_true or pred (the
    previous sklearn confusion_matrix call without `labels=` produced a
    smaller matrix in that case and cm[2, ...] raised IndexError).

    Args:
        y_true: iterable of true class labels in {0, 1, 2}.
        pred: iterable of predicted class labels in {0, 1, 2}.

    Returns:
        CSI as a float in [0, 1].
    """
    true_arr = np.asarray(y_true, dtype=np.int64)
    pred_arr = np.asarray(pred, dtype=np.int64)

    # Fixed 3x3 confusion matrix, rows = true class, columns = predicted.
    cm = np.zeros((3, 3), dtype=np.int64)
    np.add.at(cm, (true_arr, pred_arr), 1)

    H = (cm[0, 0] + cm[1, 1])
    F = (cm[1, 0] + cm[2, 0] +
         cm[0, 1] + cm[2, 1])
    M = (cm[0, 2] + cm[1, 2])

    CSI = H / (H + F + M + 1e-10)
    return CSI
|
| 379 |
+
|
| 380 |
+
def sample_weight(y_train):
    """Return a per-sample weight array from sklearn's 'balanced' scheme.

    Each sample receives the balanced class weight of its label, i.e.
    n_samples / (n_classes * count(label)).

    Args:
        y_train: iterable of training labels.

    Returns:
        np.ndarray of one weight per sample, aligned with y_train.
    """
    classes = np.unique(y_train)
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=classes,   # distinct classes, sorted
        y=y_train          # training labels
    )
    # Map each label to its weight by class VALUE.  The previous code
    # indexed class_weights[label] positionally, which is only correct
    # when the labels happen to be exactly 0..n_classes-1 and crashes or
    # silently mismatches otherwise.
    weight_by_class = dict(zip(classes, class_weights))
    sample_weights = np.array([weight_by_class[label] for label in y_train])

    return sample_weights
|
| 389 |
+
|
| 390 |
+
# ํ์ดํผํ๋ผ๋ฏธํฐ ์ต์ ํ ํจ์ ์ ์
|
| 391 |
+
def objective(trial, model_choose, region, data_sample='pure', target='multi', n_folds=3, random_state=42):
    """Optuna objective: mean best validation CSI over n_folds folds.

    Samples model-specific hyperparameters from `trial`, then for each
    fold trains the chosen model with class-weighted loss, LR scheduling
    and early stopping, reporting per-epoch CSI to Optuna for pruning.

    Args:
        trial: Optuna trial (hyperparameter sampling + pruning reports).
        model_choose: 'ft_transformer', 'resnet_like' or 'deepgbm'.
            NOTE(review): any other value leaves the sampled variables
            (incl. batch_size) undefined and raises NameError below.
        region: region name used to locate the CSV files.
        data_sample: 'pure' or the name of an oversampled dataset.
        target: 'multi' (3-class) or 'binary'.
        n_folds: number of year-based folds (fold f holds out 2021 - f).
        random_state: seed for the training DataLoader shuffling.

    Returns:
        Mean of the per-fold best validation CSI values.
    """
    # Use GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    val_scores = []

    # --- 1. Hyperparameter search space (per model) ---
    if model_choose == "ft_transformer":
        d_token = trial.suggest_int("d_token", 64, 256, step=32)
        n_blocks = trial.suggest_int("n_blocks", 2, 6)  # kept shallow to limit overfitting
        n_heads = trial.suggest_categorical("n_heads", [4, 8])
        # d_token must be a multiple of n_heads (FT-Transformer structural
        # constraint), so round it down to the nearest multiple.
        if d_token % n_heads != 0:
            d_token = (d_token // n_heads) * n_heads

        attention_dropout = trial.suggest_float("attention_dropout", 0.1, 0.4)
        ffn_dropout = trial.suggest_float("ffn_dropout", 0.1, 0.4)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])

    elif model_choose == 'resnet_like':
        d_main = trial.suggest_int("d_main", 64, 256, step=32)
        d_hidden = trial.suggest_int("d_hidden", 64, 512, step=64)
        n_blocks = trial.suggest_int("n_blocks", 2, 5)  # kept shallow
        dropout_first = trial.suggest_float("dropout_first", 0.1, 0.4)
        dropout_second = trial.suggest_float("dropout_second", 0.0, 0.2)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])

    elif model_choose == 'deepgbm':
        # DeepGBM: ResNet-block and embedding dimensions tuned to the model.
        d_main = trial.suggest_int("d_main", 64, 256, step=32)
        d_hidden = trial.suggest_int("d_hidden", 64, 256, step=64)
        n_blocks = trial.suggest_int("n_blocks", 2, 6)
        dropout = trial.suggest_float("dropout", 0.1, 0.4)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])

    # --- 2. Per-fold training and cross-validation ---
    for fold in range(1, n_folds + 1):
        X_train_df, categorical_cols, numerical_cols, train_loader, val_loader, _, y_train, _ = prepare_dataloader_with_batchsize(
            region, data_sample=data_sample, target=target, fold=fold, random_state=random_state, batch_size=batch_size
        )

        # Model construction with the sampled hyperparameters.
        if model_choose == "ft_transformer":
            model = FTTransformer(
                num_features=len(numerical_cols),
                cat_cardinalities=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_token=d_token,
                n_blocks=n_blocks,
                n_heads=n_heads,
                attention_dropout=attention_dropout,
                ffn_dropout=ffn_dropout,
                num_classes=3
            ).to(device)
        elif model_choose == 'resnet_like':
            input_dim = len(numerical_cols) + len(categorical_cols)
            model = ResNetLike(
                input_dim=input_dim,
                d_main=d_main,
                d_hidden=d_hidden,
                n_blocks=n_blocks,
                dropout_first=dropout_first,
                dropout_second=dropout_second,
                num_classes=3
            ).to(device)
        elif model_choose == 'deepgbm':
            model = DeepGBM(
                num_features=len(numerical_cols),
                cat_features=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_main=d_main,
                d_hidden=d_hidden,
                n_blocks=n_blocks,
                dropout=dropout,
                num_classes=3
            ).to(device)

        # Balanced class weights for the multi-class loss.
        if target == 'multi':
            class_weights = compute_class_weight(
                class_weight='balanced',
                classes=np.unique(y_train),
                y=y_train
            )
            # Log per-class weights and sample counts for this fold.
            unique_classes = np.unique(y_train)
            class_counts = {cls: np.sum(y_train == cls) for cls in unique_classes}
            print(f" Fold {fold} - ํด๋์ค๋ณ ๊ฐ์ค์น: {dict(zip(unique_classes, class_weights))} (ํด๋์ค๋ณ ์ํ ์: {class_counts})")
            class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
            criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.0)  # smoothing currently disabled
        else:
            criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        # Halve the LR when validation CSI plateaus for 3 epochs.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        # Training configuration (epochs and early-stopping patience).
        epochs = 200
        patience = 12
        best_fold_csi = 0
        counter = 0

        for epoch in range(epochs):
            model.train()
            for x_num_batch, x_cat_batch, y_batch in train_loader:
                x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                y_pred = model(x_num_batch, x_cat_batch)
                loss = criterion(y_pred, y_batch if target == 'multi' else y_batch.float())
                loss.backward()
                optimizer.step()

            # Validation pass (no gradients).
            model.eval()
            y_pred_val, y_true_val = [], []
            with torch.no_grad():
                for x_num_batch, x_cat_batch, y_batch in val_loader:
                    x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)
                    output = model(x_num_batch, x_cat_batch)
                    pred = output.argmax(dim=1) if target == 'multi' else (torch.sigmoid(output) >= 0.5).long()

                    y_pred_val.extend(pred.cpu().numpy())
                    y_true_val.extend(y_batch.cpu().numpy())

            # CSI on the validation set drives both the scheduler and pruning.
            val_csi = calculate_csi(y_true_val, y_pred_val)
            scheduler.step(val_csi)

            # Report to Optuna; a pruned trial aborts all remaining folds.
            trial.report(val_csi, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

            # Early stopping on the fold's best CSI.
            if val_csi > best_fold_csi:
                best_fold_csi = val_csi
                counter = 0
            else:
                counter += 1

            if counter >= patience:
                break

        val_scores.append(best_fold_csi)

    # Trial value: mean best CSI across folds.
    return np.mean(val_scores)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข
๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ํจ์
|
| 546 |
+
def train_final_model(best_params, model_choose, region, data_sample='pure', target='multi', n_folds=3, random_state=42):
    """Train the final per-fold models with tuned hyperparameters and save them.

    For each fold, rebuilds the dataloaders with the tuned batch size,
    trains with class-weighted loss, LR scheduling and early stopping,
    keeps a deep copy of the best-CSI model state, and finally persists
    the list of fold models (and their fitted scalers) with joblib.

    Args:
        best_params: tuned hyperparameter dict (from Optuna).
        model_choose: 'ft_transformer', 'resnet_like' or 'deepgbm'.
        region: region name.
        data_sample: 'pure' or the name of an oversampled dataset.
        target: 'multi' or 'binary'.
        n_folds: number of year-based folds.
        random_state: seed for the training DataLoader shuffling.

    Returns:
        Path of the saved model file (a joblib pickle holding the list of
        per-fold models).
    """
    # Use GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    models = []
    scalers = []  # fitted QuantileTransformer per fold, saved alongside the models

    print("์ต์ข ๋ชจ๋ธ ํ์ต ์์...")

    for fold in range(1, n_folds + 1):
        print(f"Fold {fold} ํ์ต ์ค...")

        # Use the tuned batch size (default 64 when absent).
        batch_size = best_params.get("batch_size", 64)
        X_train_df, categorical_cols, numerical_cols, train_loader, val_loader, _, y_train, scaler = prepare_dataloader_with_batchsize(
            region, data_sample=data_sample, target=target, fold=fold, random_state=random_state, batch_size=batch_size
        )

        # Model construction from best_params.
        if model_choose == "ft_transformer":
            d_token = best_params["d_token"]
            n_heads = best_params.get("n_heads", 8)
            # d_token must be a multiple of n_heads (FT-Transformer
            # structural constraint), so round it down.
            if d_token % n_heads != 0:
                d_token = (d_token // n_heads) * n_heads

            model = FTTransformer(
                num_features=len(numerical_cols),
                cat_cardinalities=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_token=d_token,
                n_blocks=best_params["n_blocks"],
                n_heads=n_heads,
                attention_dropout=best_params["attention_dropout"],
                ffn_dropout=best_params["ffn_dropout"],
                num_classes=3
            ).to(device)
        elif model_choose == 'resnet_like':
            input_dim = len(numerical_cols) + len(categorical_cols)
            model = ResNetLike(
                input_dim=input_dim,
                d_main=best_params["d_main"],
                d_hidden=best_params["d_hidden"],
                n_blocks=best_params["n_blocks"],
                dropout_first=best_params["dropout_first"],
                dropout_second=best_params["dropout_second"],
                num_classes=3
            ).to(device)
        elif model_choose == 'deepgbm':
            model = DeepGBM(
                num_features=len(numerical_cols),
                cat_features=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_main=best_params["d_main"],
                d_hidden=best_params["d_hidden"],
                n_blocks=best_params["n_blocks"],
                dropout=best_params["dropout"],
                num_classes=3
            ).to(device)
        else:
            raise ValueError(f"Unknown model_choose: {model_choose}")

        # Balanced class weights for the multi-class loss.
        if target == 'multi':
            class_weights = compute_class_weight(
                class_weight='balanced',
                classes=np.unique(y_train),
                y=y_train
            )
            class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
            criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.0)  # smoothing currently disabled
        else:
            criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.AdamW(model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])

        # Halve the LR when validation CSI plateaus for 3 epochs.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        # Training configuration.
        epochs = 200
        patience = 12
        best_fold_csi = 0
        counter = 0
        best_model = None

        for epoch in range(epochs):
            model.train()
            for x_num_batch, x_cat_batch, y_batch in train_loader:
                x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                y_pred = model(x_num_batch, x_cat_batch)
                loss = criterion(y_pred, y_batch if target == 'multi' else y_batch.float())
                loss.backward()
                optimizer.step()

            # Validation pass (no gradients).
            model.eval()
            y_pred_val, y_true_val = [], []
            with torch.no_grad():
                for x_num_batch, x_cat_batch, y_batch in val_loader:
                    x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)
                    output = model(x_num_batch, x_cat_batch)
                    pred = output.argmax(dim=1) if target == 'multi' else (torch.sigmoid(output) >= 0.5).long()

                    y_pred_val.extend(pred.cpu().numpy())
                    y_true_val.extend(y_batch.cpu().numpy())

            # CSI drives the scheduler and early stopping.
            val_csi = calculate_csi(y_true_val, y_pred_val)
            scheduler.step(val_csi)

            # Keep a deep copy of the best model seen so far.
            if val_csi > best_fold_csi:
                best_fold_csi = val_csi
                counter = 0
                best_model = copy.deepcopy(model)
            else:
                counter += 1

            if counter >= patience:
                print(f" Early stopping at epoch {epoch+1}, Best CSI: {best_fold_csi:.4f}")
                break

        # Fallback: if validation CSI never improved above 0, keep the last model.
        if best_model is None:
            best_model = model

        scalers.append(scaler)  # stored in fold order
        models.append(best_model)
        print(f" Fold {fold} ํ์ต ์๋ฃ (๊ฒ์ฆ CSI: {best_fold_csi:.4f})")

    # Model save location (relative to the CWD — assumes the script is run
    # from its own directory; TODO confirm).
    save_dir = f'../save_model/{model_choose}_optima'
    os.makedirs(save_dir, exist_ok=True)

    # File name follows the {model}_{sample}_{region} pattern.
    if data_sample == 'pure':
        model_filename = f'{model_choose}_pure_{region}.pkl'
    else:
        model_filename = f'{model_choose}_{data_sample}_{region}.pkl'

    model_path = f'{save_dir}/{model_filename}'

    # All fold models are saved together as one list.
    joblib.dump(models, model_path)
    print(f"\n๋ชจ๋ ๋ชจ๋ธ ์ ์ฅ ์๋ฃ: {model_path} (์ด {len(models)}๊ฐ fold)")

    # Scalers are saved separately, same naming pattern as the models.
    scaler_save_dir = f'../save_model/{model_choose}_optima/scaler'
    os.makedirs(scaler_save_dir, exist_ok=True)

    if data_sample == 'pure':
        scaler_filename = f'{model_choose}_pure_{region}_scaler.pkl'
    else:
        scaler_filename = f'{model_choose}_{data_sample}_{region}_scaler.pkl'

    scaler_path = f'{scaler_save_dir}/{scaler_filename}'
    joblib.dump(scalers, scaler_path)
    print(f"Scaler ์ ์ฅ ์๋ฃ: {scaler_path} (์ด {len(scalers)}๊ฐ fold)")

    return model_path
|
Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_busan.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Python ๋ฐ Numpy ์๋ ๊ณ ์ (fix Python/NumPy RNG seeds for reproducibility)
seed = 42
random.seed(seed)
np.random.seed(seed)


# 1. Study ์์ฑ ์ 'maximize'๋ก ์ค์
study = optuna.create_study(
    direction="maximize",  # CSI ์ ์๊ฐ ๋์์๋ก ์ข์ผ๋ฏ๋ก maximize
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # ์ด๋ฐ 10์ํญ์ ์ง์ผ๋ณด๊ณ ์ดํ ๊ฐ์ง์น๊ธฐ
)


# Trial ์๋ฃ ์ ์์ธ ์ ๋ณด ์ถ๋ ฅํ๋ callback ํจ์
def print_trial_callback(study, trial):
    """Print per-trial details (value, params) plus the best result so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # FIX: study.best_value / study.best_trial raise ValueError while the
    # study has no COMPLETED trial (e.g. the very first trial failed or was
    # pruned). The original `is not None` check could never prevent that,
    # because the attribute access itself raises -- guard with try/except.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. ์ต์ ํ ์คํ
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="busan", data_sample='smote'),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. ๊ฒฐ๊ณผ ํ์ธ ๋ฐ ์์ฝ
print(f"\n์ต์ ํ ์๋ฃ.")

try:
    # FIX: moved inside the try-block -- with zero completed trials,
    # study.best_value raises ValueError and previously crashed the script
    # before any reporting could run.
    print(f"Best CSI Score: {study.best_value:.4f}")
    print(f"Best Hyperparameters: {study.best_params}")

    # ๋ชจ๋  trial์ CSI ์ ์ ์ถ์ถ
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์  ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ  CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์  CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท  CSI: {np.mean(csi_scores):.4f}")

    # Study ๊ฐ์ฒด ์ ์ฅ
    # ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก base ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(os.path.dirname(current_file_dir))  # 5.optima ๋๋ ํ ๋ฆฌ
    os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
    study_path = os.path.join(base_dir, "optimization_history/deepgbm_smote_busan_trials.pkl")
    joblib.dump(study, study_path)
    print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    # ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ
    print("\n" + "="*50)
    print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
    print("="*50)

    best_params = study.best_params
    model_path = train_final_model(
        best_params=best_params,
        model_choose="deepgbm",
        region="busan",
        data_sample='smote',
        target='multi',
        n_folds=3,
        random_state=seed
    )

    print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
    print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# ์ ์ ์ข๋ฃ (exit cleanly so batch runners continue to the next script)
import sys
sys.exit(0)
|
| 97 |
+
|
Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_daegu.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Python ๋ฐ Numpy ์๋ ๊ณ ์ (fix Python/NumPy RNG seeds for reproducibility)
seed = 42
random.seed(seed)
np.random.seed(seed)


# 1. Study ์์ฑ ์ 'maximize'๋ก ์ค์
study = optuna.create_study(
    direction="maximize",  # CSI ์ ์๊ฐ ๋์์๋ก ์ข์ผ๋ฏ๋ก maximize
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # ์ด๋ฐ 10์ํญ์ ์ง์ผ๋ณด๊ณ ์ดํ ๊ฐ์ง์น๊ธฐ
)


# Trial ์๋ฃ ์ ์์ธ ์ ๋ณด ์ถ๋ ฅํ๋ callback ํจ์
def print_trial_callback(study, trial):
    """Print per-trial details (value, params) plus the best result so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # FIX: study.best_value / study.best_trial raise ValueError while the
    # study has no COMPLETED trial (e.g. the very first trial failed or was
    # pruned). The original `is not None` check could never prevent that,
    # because the attribute access itself raises -- guard with try/except.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. ์ต์ ํ ์คํ
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="daegu", data_sample='smote'),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. ๊ฒฐ๊ณผ ํ์ธ ๋ฐ ์์ฝ
print(f"\n์ต์ ํ ์๋ฃ.")

try:
    # FIX: moved inside the try-block -- with zero completed trials,
    # study.best_value raises ValueError and previously crashed the script
    # before any reporting could run.
    print(f"Best CSI Score: {study.best_value:.4f}")
    print(f"Best Hyperparameters: {study.best_params}")

    # ๋ชจ๋  trial์ CSI ์ ์ ์ถ์ถ
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์  ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ  CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์  CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท  CSI: {np.mean(csi_scores):.4f}")

    # Study ๊ฐ์ฒด ์ ์ฅ
    # ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก base ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(os.path.dirname(current_file_dir))  # 5.optima ๋๋ ํ ๋ฆฌ
    os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
    study_path = os.path.join(base_dir, "optimization_history/deepgbm_smote_daegu_trials.pkl")
    joblib.dump(study, study_path)
    print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    # ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ
    print("\n" + "="*50)
    print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
    print("="*50)

    best_params = study.best_params
    model_path = train_final_model(
        best_params=best_params,
        model_choose="deepgbm",
        region="daegu",
        data_sample='smote',
        target='multi',
        n_folds=3,
        random_state=seed
    )

    print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
    print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# ์ ์ ์ข๋ฃ (exit cleanly so batch runners continue to the next script)
import sys
sys.exit(0)
|
| 97 |
+
|
Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_daejeon.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Python ๋ฐ Numpy ์๋ ๊ณ ์ (fix Python/NumPy RNG seeds for reproducibility)
seed = 42
random.seed(seed)
np.random.seed(seed)


# 1. Study ์์ฑ ์ 'maximize'๋ก ์ค์
study = optuna.create_study(
    direction="maximize",  # CSI ์ ์๊ฐ ๋์์๋ก ์ข์ผ๋ฏ๋ก maximize
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # ์ด๋ฐ 10์ํญ์ ์ง์ผ๋ณด๊ณ ์ดํ ๊ฐ์ง์น๊ธฐ
)


# Trial ์๋ฃ ์ ์์ธ ์ ๋ณด ์ถ๋ ฅํ๋ callback ํจ์
def print_trial_callback(study, trial):
    """Print per-trial details (value, params) plus the best result so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # FIX: study.best_value / study.best_trial raise ValueError while the
    # study has no COMPLETED trial (e.g. the very first trial failed or was
    # pruned). The original `is not None` check could never prevent that,
    # because the attribute access itself raises -- guard with try/except.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. ์ต์ ํ ์คํ
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="daejeon", data_sample='smote'),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. ๊ฒฐ๊ณผ ํ์ธ ๋ฐ ์์ฝ
print(f"\n์ต์ ํ ์๋ฃ.")

try:
    # FIX: moved inside the try-block -- with zero completed trials,
    # study.best_value raises ValueError and previously crashed the script
    # before any reporting could run.
    print(f"Best CSI Score: {study.best_value:.4f}")
    print(f"Best Hyperparameters: {study.best_params}")

    # ๋ชจ๋  trial์ CSI ์ ์ ์ถ์ถ
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์  ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ  CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์  CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท  CSI: {np.mean(csi_scores):.4f}")

    # Study ๊ฐ์ฒด ์ ์ฅ
    # ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก base ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(os.path.dirname(current_file_dir))  # 5.optima ๋๋ ํ ๋ฆฌ
    os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
    study_path = os.path.join(base_dir, "optimization_history/deepgbm_smote_daejeon_trials.pkl")
    joblib.dump(study, study_path)
    print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    # ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ
    print("\n" + "="*50)
    print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
    print("="*50)

    best_params = study.best_params
    model_path = train_final_model(
        best_params=best_params,
        model_choose="deepgbm",
        region="daejeon",
        data_sample='smote',
        target='multi',
        n_folds=3,
        random_state=seed
    )

    print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
    print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# ์ ์ ์ข๋ฃ (exit cleanly so batch runners continue to the next script)
import sys
sys.exit(0)
|
| 97 |
+
|
Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_gwangju.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Python ๋ฐ Numpy ์๋ ๊ณ ์ (fix Python/NumPy RNG seeds for reproducibility)
seed = 42
random.seed(seed)
np.random.seed(seed)


# 1. Study ์์ฑ ์ 'maximize'๋ก ์ค์
study = optuna.create_study(
    direction="maximize",  # CSI ์ ์๊ฐ ๋์์๋ก ์ข์ผ๋ฏ๋ก maximize
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # ์ด๋ฐ 10์ํญ์ ์ง์ผ๋ณด๊ณ ์ดํ ๊ฐ์ง์น๊ธฐ
)


# Trial ์๋ฃ ์ ์์ธ ์ ๋ณด ์ถ๋ ฅํ๋ callback ํจ์
def print_trial_callback(study, trial):
    """Print per-trial details (value, params) plus the best result so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # FIX: study.best_value / study.best_trial raise ValueError while the
    # study has no COMPLETED trial (e.g. the very first trial failed or was
    # pruned). The original `is not None` check could never prevent that,
    # because the attribute access itself raises -- guard with try/except.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. ์ต์ ํ ์คํ
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="gwangju", data_sample='smote'),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. ๊ฒฐ๊ณผ ํ์ธ ๋ฐ ์์ฝ
print(f"\n์ต์ ํ ์๋ฃ.")

try:
    # FIX: moved inside the try-block -- with zero completed trials,
    # study.best_value raises ValueError and previously crashed the script
    # before any reporting could run.
    print(f"Best CSI Score: {study.best_value:.4f}")
    print(f"Best Hyperparameters: {study.best_params}")

    # ๋ชจ๋  trial์ CSI ์ ์ ์ถ์ถ
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์  ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ  CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์  CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท  CSI: {np.mean(csi_scores):.4f}")

    # Study ๊ฐ์ฒด ์ ์ฅ
    # ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก base ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(os.path.dirname(current_file_dir))  # 5.optima ๋๋ ํ ๋ฆฌ
    os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
    study_path = os.path.join(base_dir, "optimization_history/deepgbm_smote_gwangju_trials.pkl")
    joblib.dump(study, study_path)
    print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    # ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ
    print("\n" + "="*50)
    print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
    print("="*50)

    best_params = study.best_params
    model_path = train_final_model(
        best_params=best_params,
        model_choose="deepgbm",
        region="gwangju",
        data_sample='smote',
        target='multi',
        n_folds=3,
        random_state=seed
    )

    print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
    print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# ์ ์ ์ข๋ฃ (exit cleanly so batch runners continue to the next script)
import sys
sys.exit(0)
|
| 97 |
+
|
Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_incheon.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Python ๋ฐ Numpy ์๋ ๊ณ ์ (fix Python/NumPy RNG seeds for reproducibility)
seed = 42
random.seed(seed)
np.random.seed(seed)


# 1. Study ์์ฑ ์ 'maximize'๋ก ์ค์
study = optuna.create_study(
    direction="maximize",  # CSI ์ ์๊ฐ ๋์์๋ก ์ข์ผ๋ฏ๋ก maximize
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # ์ด๋ฐ 10์ํญ์ ์ง์ผ๋ณด๊ณ ์ดํ ๊ฐ์ง์น๊ธฐ
)


# Trial ์๋ฃ ์ ์์ธ ์ ๋ณด ์ถ๋ ฅํ๋ callback ํจ์
def print_trial_callback(study, trial):
    """Print per-trial details (value, params) plus the best result so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # FIX: study.best_value / study.best_trial raise ValueError while the
    # study has no COMPLETED trial (e.g. the very first trial failed or was
    # pruned). The original `is not None` check could never prevent that,
    # because the attribute access itself raises -- guard with try/except.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. ์ต์ ํ ์คํ
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="incheon", data_sample='smote'),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. ๊ฒฐ๊ณผ ํ์ธ ๋ฐ ์์ฝ
print(f"\n์ต์ ํ ์๋ฃ.")

try:
    # FIX: moved inside the try-block -- with zero completed trials,
    # study.best_value raises ValueError and previously crashed the script
    # before any reporting could run.
    print(f"Best CSI Score: {study.best_value:.4f}")
    print(f"Best Hyperparameters: {study.best_params}")

    # ๋ชจ๋  trial์ CSI ์ ์ ์ถ์ถ
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์  ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ  CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์  CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท  CSI: {np.mean(csi_scores):.4f}")

    # Study ๊ฐ์ฒด ์ ์ฅ
    # ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก base ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(os.path.dirname(current_file_dir))  # 5.optima ๋๋ ํ ๋ฆฌ
    os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
    study_path = os.path.join(base_dir, "optimization_history/deepgbm_smote_incheon_trials.pkl")
    joblib.dump(study, study_path)
    print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    # ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ
    print("\n" + "="*50)
    print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
    print("="*50)

    best_params = study.best_params
    model_path = train_final_model(
        best_params=best_params,
        model_choose="deepgbm",
        region="incheon",
        data_sample='smote',
        target='multi',
        n_folds=3,
        random_state=seed
    )

    print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
    print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# ์ ์ ์ข๋ฃ (exit cleanly so batch runners continue to the next script)
import sys
sys.exit(0)
|
| 97 |
+
|
Analysis_code/5.optima/deepgbm_smote/deepgbm_smote_seoul.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import torch
from utils import *

# Python ๋ฐ Numpy ์๋ ๊ณ ์ (fix Python/NumPy RNG seeds for reproducibility)
seed = 42
random.seed(seed)
np.random.seed(seed)


# 1. Study ์์ฑ ์ 'maximize'๋ก ์ค์
study = optuna.create_study(
    direction="maximize",  # CSI ์ ์๊ฐ ๋์์๋ก ์ข์ผ๋ฏ๋ก maximize
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)  # ์ด๋ฐ 10์ํญ์ ์ง์ผ๋ณด๊ณ ์ดํ ๊ฐ์ง์น๊ธฐ
)


# Trial ์๋ฃ ์ ์์ธ ์ ๋ณด ์ถ๋ ฅํ๋ callback ํจ์
def print_trial_callback(study, trial):
    """Print per-trial details (value, params) plus the best result so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # FIX: study.best_value / study.best_trial raise ValueError while the
    # study has no COMPLETED trial (e.g. the very first trial failed or was
    # pruned). The original `is not None` check could never prevent that,
    # because the attribute access itself raises -- guard with try/except.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. ์ต์ ํ ์คํ
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="seoul", data_sample='smote'),
    n_trials=100,
    callbacks=[print_trial_callback]
)

# 3. ๊ฒฐ๊ณผ ํ์ธ ๋ฐ ์์ฝ
print(f"\n์ต์ ํ ์๋ฃ.")

try:
    # FIX: moved inside the try-block -- with zero completed trials,
    # study.best_value raises ValueError and previously crashed the script
    # before any reporting could run.
    print(f"Best CSI Score: {study.best_value:.4f}")
    print(f"Best Hyperparameters: {study.best_params}")

    # ๋ชจ๋  trial์ CSI ์ ์ ์ถ์ถ
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if len(csi_scores) > 0:
        print(f"\n์ต์ ํ ๊ณผ์  ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ  CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์  CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท  CSI: {np.mean(csi_scores):.4f}")

    # Study ๊ฐ์ฒด ์ ์ฅ
    # ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก base ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(os.path.dirname(current_file_dir))  # 5.optima ๋๋ ํ ๋ฆฌ
    os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
    study_path = os.path.join(base_dir, "optimization_history/deepgbm_smote_seoul_trials.pkl")
    joblib.dump(study, study_path)
    print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    # ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ
    print("\n" + "="*50)
    print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
    print("="*50)

    best_params = study.best_params
    model_path = train_final_model(
        best_params=best_params,
        model_choose="deepgbm",
        region="seoul",
        data_sample='smote',
        target='multi',
        n_folds=3,
        random_state=seed
    )

    print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
    print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    import traceback
    traceback.print_exc()

# ์ ์ ์ข๋ฃ (exit cleanly so batch runners continue to the next script)
import sys
sys.exit(0)
|
| 97 |
+
|
Analysis_code/5.optima/deepgbm_smote/utils.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import copy
|
| 8 |
+
from sklearn.preprocessing import QuantileTransformer, LabelEncoder
|
| 9 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 10 |
+
from sklearn.metrics import confusion_matrix
|
| 11 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import optuna
|
| 14 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 15 |
+
import joblib
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
# ํ์ผ ์์น ๊ธฐ๋ฐ์ผ๋ก models ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
|
| 20 |
+
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
| 21 |
+
models_path = os.path.abspath(os.path.join(current_file_dir, '../../models'))
|
| 22 |
+
sys.path.insert(0, models_path)
|
| 23 |
+
from ft_transformer import FTTransformer
|
| 24 |
+
from resnet_like import ResNetLike
|
| 25 |
+
from deepgbm import DeepGBM
|
| 26 |
+
import warnings
|
| 27 |
+
warnings.filterwarnings('ignore')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Fix Python and NumPy RNG seeds for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)

# Fix PyTorch seeds.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # apply the same seed on every GPU in multi-GPU setups

# Deterministic cuDNN configuration.
# benchmark mode autotunes convolution algorithms at runtime and can select
# different kernels from run to run, which defeats the reproducibility that
# deterministic=True is meant to guarantee -- so it must be disabled too.
torch.backends.cudnn.deterministic = True  # same results on every run
torch.backends.cudnn.benchmark = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def add_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Restore the derived feature columns that were removed upstream.

    Adds cyclic (sin/cos) encodings for the hour-of-day and month-of-year
    fields plus the ground-air temperature gap, on a copy of the input.

    Args:
        df: input dataframe with 'hour', 'month', 'groundtemp', 'temp_C'.

    Returns:
        A new dataframe with the derived columns appended.
    """
    out = df.copy()
    # Cyclic encodings: map hour onto a 24-step circle, month onto a 12-step one.
    hour_angle = 2 * np.pi * out['hour'] / 24
    month_angle = 2 * np.pi * out['month'] / 12
    out['hour_sin'] = np.sin(hour_angle)
    out['hour_cos'] = np.cos(hour_angle)
    out['month_sin'] = np.sin(month_angle)
    out['month_cos'] = np.cos(month_angle)
    # Ground-air temperature difference.
    out['ground_temp - temp_C'] = out['groundtemp'] - out['temp_C']
    return out
|
| 62 |
+
|
| 63 |
+
def preprocessing(df):
    """Cast dtypes, rebuild derived features and select the modeling columns.

    Args:
        df: raw dataframe.

    Returns:
        A new, preprocessed dataframe; the input is not modified.
    """
    out = df[df.columns].copy()
    # Calendar fields must be integers before the derived features are built.
    for cal_col in ('year', 'month', 'hour'):
        out[cal_col] = out[cal_col].astype('int')
    out = add_derived_features(out).copy()
    out['multi_class'] = out['multi_class'].astype('int')
    # Calm wind direction is recorded as text; recode it to 0 before the cast.
    out.loc[out['wind_dir']=='์ ์จ', 'wind_dir'] = "0"
    out['wind_dir'] = out['wind_dir'].astype('int')
    # Keep only the columns used for modeling, in a fixed order.
    keep_cols = ['temp_C', 'precip_mm', 'wind_speed', 'wind_dir', 'hm',
                 'vap_pressure', 'dewpoint_C', 'loc_pressure', 'sea_pressure',
                 'solarRad', 'snow_cm', 'cloudcover', 'lm_cloudcover', 'low_cloudbase',
                 'groundtemp', 'O3', 'NO2', 'PM10', 'PM25', 'year',
                 'month', 'hour', 'ground_temp - temp_C', 'hour_sin', 'hour_cos',
                 'month_sin', 'month_cos','multi_class']
    return out[keep_cols].copy()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ๋ฐ์ดํฐ์
์ค๋น ํจ์
|
| 90 |
+
def prepare_dataset(region, data_sample='pure', target='multi', fold=3):
    """Load one region's train/val/test data and encode it for tabular models.

    Validation is always the single year ``2021 - fold`` carved out of the raw
    training csv. When ``data_sample`` is not 'pure', the training split is
    replaced by a pre-generated oversampled csv while validation still comes
    from the raw data.

    Args:
        region: region name used to build the csv file names.
        data_sample: 'pure' for the raw training data, otherwise an
            oversampling scheme name (directory/file prefix under
            ``data_oversampled``).
        target: 'multi' or 'binary'; selects the ``{target}_class`` column.
        fold: fold index; year ``2021 - fold`` becomes the validation year.

    Returns:
        Tuple ``(X_train, X_val, X_test, y_train, y_val, y_test,
        categorical_cols, numerical_cols)`` with label-encoded categorical
        columns and quantile-transformed numerical columns.
    """

    # Resolve the data directory relative to this file's location.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    data_base_dir = os.path.abspath(os.path.join(current_file_dir, '../../../data'))

    # Build the csv paths.
    dat_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_train.csv")
    if data_sample == 'pure':
        train_path = dat_path
    else:
        train_path = os.path.join(data_base_dir, f'data_oversampled/{data_sample}/{data_sample}_{fold}_{region}.csv')
    test_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_test.csv")
    drop_col = ['multi_class','year']
    target_col = f'{target}_class'

    # Load the data; validation is always taken from the raw training csv.
    region_dat = preprocessing(pd.read_csv(dat_path, index_col=0))
    if data_sample == 'pure':
        region_train = region_dat.loc[~region_dat['year'].isin([2021-fold]), :]
    else:
        region_train = preprocessing(pd.read_csv(train_path))
    region_val = region_dat.loc[region_dat['year'].isin([2021-fold]), :]
    region_test = preprocessing(pd.read_csv(test_path))

    # Align column order across the three splits for consistency.
    common_columns = region_train.columns.to_list()
    train_data = region_train[common_columns]
    val_data = region_val[common_columns]
    test_data = region_test[common_columns]

    # Split features and target.
    X_train = train_data.drop(columns=drop_col)
    y_train = train_data[target_col]
    X_val = val_data.drop(columns=drop_col)
    y_val = val_data[target_col]
    X_test = test_data.drop(columns=drop_col)
    y_test = test_data[target_col]

    # Split categorical vs. continuous columns (int64 is treated as categorical).
    categorical_cols = X_train.select_dtypes(include=['object', 'category', 'int64']).columns
    numerical_cols = X_train.select_dtypes(include=['float64']).columns

    # Label-encode categorical columns.
    # NOTE(review): encoders are fit on train only, so a category that appears
    # only in val/test would raise at transform time -- presumably all
    # categories occur in train; confirm against the data.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        le.fit(X_train[col])  # fit on the training data only
        label_encoders[col] = le

    # Apply the encoders to all three splits.
    for col in categorical_cols:
        X_train[col] = label_encoders[col].transform(X_train[col])
        X_val[col] = label_encoders[col].transform(X_val[col])
        X_test[col] = label_encoders[col].transform(X_test[col])

    # Quantile-transform continuous columns to an approximately normal distribution.
    scaler = QuantileTransformer(output_distribution='normal')
    scaler.fit(X_train[numerical_cols])  # fit on the training data only

    # Apply the scaler to all three splits.
    X_train[numerical_cols] = scaler.transform(X_train[numerical_cols])
    X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
    X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    return X_train, X_val, X_test, y_train, y_val, y_test, categorical_cols, numerical_cols
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ๋ฐ์ดํฐ ๋ณํ ๋ฐ dataloader ์์ฑ ํจ์
|
| 160 |
+
def prepare_dataloader(region, data_sample='pure', target='multi', fold=3, random_state=None):
    """Load, encode and wrap one region's data into PyTorch DataLoaders.

    Same split/encoding pipeline as :func:`prepare_dataset`, then converts the
    splits into tensors and DataLoaders with a fixed batch size of 64.

    Args:
        region: region name used to build the csv file names.
        data_sample: 'pure' for the raw training data, otherwise an
            oversampling scheme name under ``data_oversampled``.
        target: 'multi' (long labels) or 'binary' (float labels).
        fold: fold index; year ``2021 - fold`` becomes the validation year.
        random_state: optional seed for the training loader's shuffle
            generator; None keeps the default (non-seeded) generator.

    Returns:
        Tuple ``(X_train, categorical_cols, numerical_cols, train_loader,
        val_loader, test_loader)``.

    Raises:
        ValueError: if ``target`` is neither 'binary' nor 'multi'.
    """

    # Resolve the data directory relative to this file's location.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    data_base_dir = os.path.abspath(os.path.join(current_file_dir, '../../../data'))

    # Build the csv paths.
    dat_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_train.csv")
    if data_sample == 'pure':
        train_path = dat_path
    else:
        train_path = os.path.join(data_base_dir, f'data_oversampled/{data_sample}/{data_sample}_{fold}_{region}.csv')
    test_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_test.csv")
    drop_col = ['multi_class','year']
    target_col = f'{target}_class'

    # Load the data; validation is always taken from the raw training csv.
    region_dat = preprocessing(pd.read_csv(dat_path, index_col=0))
    if data_sample == 'pure':
        region_train = region_dat.loc[~region_dat['year'].isin([2021-fold]), :]
    else:
        region_train = preprocessing(pd.read_csv(train_path))
    region_val = region_dat.loc[region_dat['year'].isin([2021-fold]), :]
    region_test = preprocessing(pd.read_csv(test_path))

    # Align column order across the three splits for consistency.
    common_columns = region_train.columns.to_list()
    train_data = region_train[common_columns]
    val_data = region_val[common_columns]
    test_data = region_test[common_columns]

    # Split features and target.
    X_train = train_data.drop(columns=drop_col)
    y_train = train_data[target_col]
    X_val = val_data.drop(columns=drop_col)
    y_val = val_data[target_col]
    X_test = test_data.drop(columns=drop_col)
    y_test = test_data[target_col]

    # Split categorical vs. continuous columns (int64 is treated as categorical).
    categorical_cols = X_train.select_dtypes(include=['object', 'category', 'int64']).columns
    numerical_cols = X_train.select_dtypes(include=['float64']).columns

    # Label-encode categorical columns, fitting on train only.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        le.fit(X_train[col])  # fit on the training data only
        label_encoders[col] = le

    # Apply the encoders to all three splits.
    for col in categorical_cols:
        X_train[col] = label_encoders[col].transform(X_train[col])
        X_val[col] = label_encoders[col].transform(X_val[col])
        X_test[col] = label_encoders[col].transform(X_test[col])

    # Quantile-transform continuous columns to an approximately normal distribution.
    scaler = QuantileTransformer(output_distribution='normal')
    scaler.fit(X_train[numerical_cols])  # fit on the training data only

    # Apply the scaler to all three splits.
    X_train[numerical_cols] = scaler.transform(X_train[numerical_cols])
    X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
    X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    # Separate continuous and categorical tensors per split.
    X_train_num = torch.tensor(X_train[numerical_cols].values, dtype=torch.float32)
    X_train_cat = torch.tensor(X_train[categorical_cols].values, dtype=torch.long)

    X_val_num = torch.tensor(X_val[numerical_cols].values, dtype=torch.float32)
    X_val_cat = torch.tensor(X_val[categorical_cols].values, dtype=torch.long)

    X_test_num = torch.tensor(X_test[numerical_cols].values, dtype=torch.float32)
    X_test_cat = torch.tensor(X_test[categorical_cols].values, dtype=torch.long)

    # Label tensors: binary targets need float32 (BCE), multi-class needs long (CE).
    if target == "binary":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.float32)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)
    elif target == "multi":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.long)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)
    else:
        raise ValueError("target must be 'binary' or 'multi'")

    # Build the TensorDatasets.
    train_dataset = TensorDataset(X_train_num, X_train_cat, y_train_tensor)
    val_dataset = TensorDataset(X_val_num, X_val_cat, y_val_tensor)
    test_dataset = TensorDataset(X_test_num, X_test_cat, y_test_tensor)

    # Build the DataLoaders; `is None` replaces the non-idiomatic `== None`.
    if random_state is None:
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    else:
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=torch.Generator().manual_seed(random_state))
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    return X_train, categorical_cols, numerical_cols, train_loader, val_loader, test_loader
|
| 261 |
+
|
| 262 |
+
# ๋ฐ์ดํฐ ๋ณํ ๋ฐ dataloader ์์ฑ ํจ์ (batch_size ํ๋ผ๋ฏธํฐ ์ถ๊ฐ ๋ฒ์ )
|
| 263 |
+
def prepare_dataloader_with_batchsize(region, data_sample='pure', target='multi', fold=3, random_state=None, batch_size=64):
    """Load, encode and wrap one region's data into DataLoaders (tunable batch size).

    Variant of :func:`prepare_dataloader` that exposes ``batch_size`` and also
    returns the training labels and the fitted scaler (needed for later
    inference-time transformation).

    Args:
        region: region name used to build the csv file names.
        data_sample: 'pure' for the raw training data, otherwise an
            oversampling scheme name under ``data_oversampled``.
        target: 'multi' (long labels) or 'binary' (float labels).
        fold: fold index; year ``2021 - fold`` becomes the validation year.
        random_state: optional seed for the training loader's shuffle
            generator; None keeps the default (non-seeded) generator.
        batch_size: batch size used for all three loaders.

    Returns:
        Tuple ``(X_train, categorical_cols, numerical_cols, train_loader,
        val_loader, test_loader, y_train, scaler)``.

    Raises:
        ValueError: if ``target`` is neither 'binary' nor 'multi'.
    """
    # Resolve the data directory relative to this file's location.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    data_base_dir = os.path.abspath(os.path.join(current_file_dir, '../../../data'))

    # Build the csv paths.
    dat_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_train.csv")
    if data_sample == 'pure':
        train_path = dat_path
    else:
        train_path = os.path.join(data_base_dir, f'data_oversampled/{data_sample}/{data_sample}_{fold}_{region}.csv')
    test_path = os.path.join(data_base_dir, f"data_for_modeling/{region}_test.csv")
    drop_col = ['multi_class','year']
    target_col = f'{target}_class'

    # Load the data; validation is always taken from the raw training csv.
    region_dat = preprocessing(pd.read_csv(dat_path, index_col=0))
    if data_sample == 'pure':
        region_train = region_dat.loc[~region_dat['year'].isin([2021-fold]), :]
    else:
        region_train = preprocessing(pd.read_csv(train_path))
    region_val = region_dat.loc[region_dat['year'].isin([2021-fold]), :]
    region_test = preprocessing(pd.read_csv(test_path))

    # Align column order across the three splits for consistency.
    common_columns = region_train.columns.to_list()
    train_data = region_train[common_columns]
    val_data = region_val[common_columns]
    test_data = region_test[common_columns]

    # Split features and target.
    X_train = train_data.drop(columns=drop_col)
    y_train = train_data[target_col]
    X_val = val_data.drop(columns=drop_col)
    y_val = val_data[target_col]
    X_test = test_data.drop(columns=drop_col)
    y_test = test_data[target_col]

    # Split categorical vs. continuous columns (int64 is treated as categorical).
    categorical_cols = X_train.select_dtypes(include=['object', 'category', 'int64']).columns
    numerical_cols = X_train.select_dtypes(include=['float64']).columns

    # Label-encode categorical columns, fitting on train only.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        le.fit(X_train[col])  # fit on the training data only
        label_encoders[col] = le

    # Apply the encoders to all three splits.
    for col in categorical_cols:
        X_train[col] = label_encoders[col].transform(X_train[col])
        X_val[col] = label_encoders[col].transform(X_val[col])
        X_test[col] = label_encoders[col].transform(X_test[col])

    # Quantile-transform continuous columns to an approximately normal distribution.
    scaler = QuantileTransformer(output_distribution='normal')
    scaler.fit(X_train[numerical_cols])  # fit on the training data only

    # Apply the scaler to all three splits.
    X_train[numerical_cols] = scaler.transform(X_train[numerical_cols])
    X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
    X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    # Separate continuous and categorical tensors per split.
    X_train_num = torch.tensor(X_train[numerical_cols].values, dtype=torch.float32)
    X_train_cat = torch.tensor(X_train[categorical_cols].values, dtype=torch.long)

    X_val_num = torch.tensor(X_val[numerical_cols].values, dtype=torch.float32)
    X_val_cat = torch.tensor(X_val[categorical_cols].values, dtype=torch.long)

    X_test_num = torch.tensor(X_test[numerical_cols].values, dtype=torch.float32)
    X_test_cat = torch.tensor(X_test[categorical_cols].values, dtype=torch.long)

    # Label tensors: binary targets need float32 (BCE), multi-class needs long (CE).
    if target == "binary":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.float32)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)
    elif target == "multi":
        y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)
        y_val_tensor = torch.tensor(y_val.values, dtype=torch.long)
        y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)
    else:
        raise ValueError("target must be 'binary' or 'multi'")

    # Build the TensorDatasets.
    train_dataset = TensorDataset(X_train_num, X_train_cat, y_train_tensor)
    val_dataset = TensorDataset(X_val_num, X_val_cat, y_val_tensor)
    test_dataset = TensorDataset(X_test_num, X_test_cat, y_test_tensor)

    # Build the DataLoaders with the requested batch size;
    # `is None` replaces the non-idiomatic `== None`.
    if random_state is None:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator().manual_seed(random_state))
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return X_train, categorical_cols, numerical_cols, train_loader, val_loader, test_loader, y_train, scaler
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def calculate_csi(y_true, pred):
    """Compute the Critical Success Index over a 3-class confusion matrix.

    Classes 0 and 1 are the "event" classes, class 2 the "non-event" class:
    hits (H) are correct class-0/1 predictions, false alarms (F) are wrong
    predictions into class 0/1, and misses (M) are true class-0/1 samples
    predicted as class 2.

    Fix: the original built the matrix with sklearn's ``confusion_matrix``
    without ``labels``, so whenever a class was absent from the inputs the
    matrix was smaller than 3x3 and the fixed indexing below raised
    IndexError. The matrix is now built with an explicit 3x3 shape.

    Args:
        y_true: true labels in {0, 1, 2}.
        pred: predicted labels in {0, 1, 2}.

    Returns:
        CSI = H / (H + F + M), with a small epsilon to avoid division by zero.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    pred = np.asarray(pred, dtype=np.int64)

    # cm[i, j] = number of samples with true class i predicted as class j.
    # Built with a fixed 3x3 shape so absent classes still yield valid rows/cols.
    cm = np.zeros((3, 3), dtype=np.int64)
    np.add.at(cm, (y_true, pred), 1)

    # Hits: correct predictions for the event classes 0 and 1.
    H = (cm[0, 0] + cm[1, 1])

    # False alarms: wrong predictions into class 0 or class 1.
    F = (cm[1, 0] + cm[2, 0] +
         cm[0, 1] + cm[2, 1])

    # Misses: true event samples predicted as the non-event class 2.
    M = (cm[0, 2] + cm[1, 2])

    # CSI with epsilon guard against an all-zero denominator.
    CSI = H / (H + F + M + 1e-10)
    return CSI
|
| 379 |
+
|
| 380 |
+
def sample_weight(y_train):
    """Return per-sample weights using sklearn's 'balanced' class weighting.

    Fix: ``compute_class_weight`` returns one weight per class *ordered by
    ``np.unique(y_train)``*, so weights must be looked up by class value, not
    by positional index. The original ``class_weights[label]`` was only
    correct when the labels happened to be exactly 0..k-1; otherwise it
    produced wrong weights or an IndexError.

    Args:
        y_train: training labels (array-like).

    Returns:
        np.ndarray of one weight per sample, aligned with ``y_train``.
    """
    classes = np.unique(y_train)
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=classes,  # unique classes, sorted
        y=y_train         # training labels
    )
    # Map each class value to its weight before the per-sample lookup.
    weight_by_class = dict(zip(classes, class_weights))
    sample_weights = np.array([weight_by_class[label] for label in y_train])

    return sample_weights
|
| 389 |
+
|
| 390 |
+
# ํ์ดํผํ๋ผ๋ฏธํฐ ์ต์ ํ ํจ์ ์ ์
|
| 391 |
+
def objective(trial, model_choose, region, data_sample='pure', target='multi', n_folds=3, random_state=42):
    """Optuna objective: mean best validation CSI across cross-validation folds.

    For each fold, builds the chosen model with trial-suggested
    hyperparameters, trains with early stopping on validation CSI, and
    reports per-epoch CSI to Optuna for pruning.

    Args:
        trial: Optuna trial object.
        model_choose: 'ft_transformer', 'resnet_like' or 'deepgbm'.
        region: region name passed to the data loader.
        data_sample: 'pure' or an oversampling scheme name.
        target: 'multi' or 'binary'.
        n_folds: number of cross-validation folds.
        random_state: seed forwarded to the training DataLoader shuffle.

    Returns:
        Mean of the per-fold best validation CSI values.
    """
    # Select GPU when available, else CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    val_scores = []

    # --- 1. Hyperparameter search spaces per model family ---
    # NOTE(review): there is no else branch here -- an unknown model_choose
    # leaves batch_size (and the rest) undefined and fails later with
    # NameError; consider raising ValueError up front.
    if model_choose == "ft_transformer":
        d_token = trial.suggest_int("d_token", 64, 256, step=32)
        n_blocks = trial.suggest_int("n_blocks", 2, 6)  # moderate depth to limit overfitting
        n_heads = trial.suggest_categorical("n_heads", [4, 8])
        # d_token must be a multiple of n_heads (FT-Transformer structural constraint).
        if d_token % n_heads != 0:
            d_token = (d_token // n_heads) * n_heads

        attention_dropout = trial.suggest_float("attention_dropout", 0.1, 0.4)
        ffn_dropout = trial.suggest_float("ffn_dropout", 0.1, 0.4)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)  # tuned range
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)  # deliberately wide range
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])  # batch size also searched

    elif model_choose == 'resnet_like':
        d_main = trial.suggest_int("d_main", 64, 256, step=32)
        d_hidden = trial.suggest_int("d_hidden", 64, 512, step=64)
        n_blocks = trial.suggest_int("n_blocks", 2, 5)  # kept shallow on purpose
        dropout_first = trial.suggest_float("dropout_first", 0.1, 0.4)
        dropout_second = trial.suggest_float("dropout_second", 0.0, 0.2)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)  # deliberately wide range
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])  # batch size also searched

    elif model_choose == 'deepgbm':
        # DeepGBM: block and embedding sizes tuned to the model's structure.
        d_main = trial.suggest_int("d_main", 64, 256, step=32)
        d_hidden = trial.suggest_int("d_hidden", 64, 256, step=64)
        n_blocks = trial.suggest_int("n_blocks", 2, 6)
        dropout = trial.suggest_float("dropout", 0.1, 0.4)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)  # deliberately wide range
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])  # batch size also searched

    # --- 2. Per-fold training and cross-validation ---
    for fold in range(1, n_folds + 1):
        X_train_df, categorical_cols, numerical_cols, train_loader, val_loader, _, y_train, _ = prepare_dataloader_with_batchsize(
            region, data_sample=data_sample, target=target, fold=fold, random_state=random_state, batch_size=batch_size
        )

        # Model initialization for the chosen architecture.
        if model_choose == "ft_transformer":
            model = FTTransformer(
                num_features=len(numerical_cols),
                cat_cardinalities=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_token=d_token,
                n_blocks=n_blocks,
                n_heads=n_heads,
                attention_dropout=attention_dropout,
                ffn_dropout=ffn_dropout,
                num_classes=3
            ).to(device)
        elif model_choose == 'resnet_like':
            input_dim = len(numerical_cols) + len(categorical_cols)
            model = ResNetLike(
                input_dim=input_dim,
                d_main=d_main,
                d_hidden=d_hidden,
                n_blocks=n_blocks,
                dropout_first=dropout_first,
                dropout_second=dropout_second,
                num_classes=3
            ).to(device)
        elif model_choose == 'deepgbm':
            model = DeepGBM(
                num_features=len(numerical_cols),
                cat_features=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_main=d_main,
                d_hidden=d_hidden,
                n_blocks=n_blocks,
                dropout=dropout,
                num_classes=3
            ).to(device)

        # Class weights and loss (label smoothing currently disabled: 0.0).
        if target == 'multi':
            class_weights = compute_class_weight(
                class_weight='balanced',
                classes=np.unique(y_train),
                y=y_train
            )
            # Log the per-class weights for this fold.
            unique_classes = np.unique(y_train)
            class_counts = {cls: np.sum(y_train == cls) for cls in unique_classes}
            print(f" Fold {fold} - ํด๋์ค๋ณ ๊ฐ์ค์น: {dict(zip(unique_classes, class_weights))} (ํด๋์ค๋ณ ์ํ ์: {class_counts})")
            class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
            criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.0)
        else:
            criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        # LR scheduler: halve the LR when validation CSI plateaus.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        # Training configuration (epoch cap and early-stopping patience).
        epochs = 200
        patience = 12  # generous patience given the long training horizon
        best_fold_csi = 0
        counter = 0

        for epoch in range(epochs):
            model.train()
            for x_num_batch, x_cat_batch, y_batch in train_loader:
                x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                y_pred = model(x_num_batch, x_cat_batch)
                loss = criterion(y_pred, y_batch if target == 'multi' else y_batch.float())
                loss.backward()
                optimizer.step()

            # Validation pass.
            model.eval()
            y_pred_val, y_true_val = [], []
            with torch.no_grad():
                for x_num_batch, x_cat_batch, y_batch in val_loader:
                    x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)
                    output = model(x_num_batch, x_cat_batch)
                    pred = output.argmax(dim=1) if target == 'multi' else (torch.sigmoid(output) >= 0.5).long()

                    y_pred_val.extend(pred.cpu().numpy())
                    y_true_val.extend(y_batch.cpu().numpy())

            # CSI and scheduler update.
            val_csi = calculate_csi(y_true_val, y_pred_val)
            scheduler.step(val_csi)

            # Optuna pruning (per-epoch report enables early trial termination).
            trial.report(val_csi, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

            # Early-stopping bookkeeping.
            if val_csi > best_fold_csi:
                best_fold_csi = val_csi
                counter = 0
            else:
                counter += 1

            if counter >= patience:
                break

        val_scores.append(best_fold_csi)

    # Mean performance across all folds.
    return np.mean(val_scores)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# ์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข
๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ํจ์
|
| 546 |
+
def train_final_model(best_params, model_choose, region, data_sample='pure', target='multi', n_folds=3, random_state=42):
    """Train the final models with the optimized hyperparameters and save them.

    One model is trained per fold (early-stopped on validation CSI, best
    checkpoint kept via deepcopy); the per-fold models and their fitted
    scalers are serialized as lists with joblib.

    Args:
        best_params: optimized hyperparameter dictionary.
        model_choose: model selection ('ft_transformer', 'resnet_like', 'deepgbm').
        region: region name.
        data_sample: data sampling type ('pure', 'smote', etc.).
        target: target type ('multi', 'binary').
        n_folds: number of cross-validation folds.
        random_state: random seed.

    Returns:
        Path of the saved model file (all folds in one pickle).

    Raises:
        ValueError: if ``model_choose`` is unknown.
    """
    # Select GPU when available, else CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    models = []
    scalers = []  # one fitted scaler per fold

    print("์ต์ข ๋ชจ๋ธ ํ์ต ์์...")

    for fold in range(1, n_folds + 1):
        print(f"Fold {fold} ํ์ต ์ค...")

        # Use the optimized batch size (falls back to 64 if absent).
        batch_size = best_params.get("batch_size", 64)
        X_train_df, categorical_cols, numerical_cols, train_loader, val_loader, _, y_train, scaler = prepare_dataloader_with_batchsize(
            region, data_sample=data_sample, target=target, fold=fold, random_state=random_state, batch_size=batch_size
        )

        # Model initialization for the chosen architecture.
        if model_choose == "ft_transformer":
            d_token = best_params["d_token"]
            n_heads = best_params.get("n_heads", 8)
            # d_token must be a multiple of n_heads (FT-Transformer structural constraint).
            if d_token % n_heads != 0:
                d_token = (d_token // n_heads) * n_heads

            model = FTTransformer(
                num_features=len(numerical_cols),
                cat_cardinalities=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_token=d_token,
                n_blocks=best_params["n_blocks"],
                n_heads=n_heads,
                attention_dropout=best_params["attention_dropout"],
                ffn_dropout=best_params["ffn_dropout"],
                num_classes=3
            ).to(device)
        elif model_choose == 'resnet_like':
            input_dim = len(numerical_cols) + len(categorical_cols)
            model = ResNetLike(
                input_dim=input_dim,
                d_main=best_params["d_main"],
                d_hidden=best_params["d_hidden"],
                n_blocks=best_params["n_blocks"],
                dropout_first=best_params["dropout_first"],
                dropout_second=best_params["dropout_second"],
                num_classes=3
            ).to(device)
        elif model_choose == 'deepgbm':
            model = DeepGBM(
                num_features=len(numerical_cols),
                cat_features=[len(X_train_df[col].unique()) for col in categorical_cols],
                d_main=best_params["d_main"],
                d_hidden=best_params["d_hidden"],
                n_blocks=best_params["n_blocks"],
                dropout=best_params["dropout"],
                num_classes=3
            ).to(device)
        else:
            raise ValueError(f"Unknown model_choose: {model_choose}")

        # Class weights and loss (label smoothing currently disabled: 0.0).
        if target == 'multi':
            class_weights = compute_class_weight(
                class_weight='balanced',
                classes=np.unique(y_train),
                y=y_train
            )
            class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
            criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.0)
        else:
            criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.AdamW(model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])

        # LR scheduler: halve the LR when validation CSI plateaus.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        # Training configuration.
        epochs = 200
        patience = 12
        best_fold_csi = 0
        counter = 0
        best_model = None

        for epoch in range(epochs):
            model.train()
            for x_num_batch, x_cat_batch, y_batch in train_loader:
                x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                y_pred = model(x_num_batch, x_cat_batch)
                loss = criterion(y_pred, y_batch if target == 'multi' else y_batch.float())
                loss.backward()
                optimizer.step()

            # Validation pass.
            model.eval()
            y_pred_val, y_true_val = [], []
            with torch.no_grad():
                for x_num_batch, x_cat_batch, y_batch in val_loader:
                    x_num_batch, x_cat_batch, y_batch = x_num_batch.to(device), x_cat_batch.to(device), y_batch.to(device)
                    output = model(x_num_batch, x_cat_batch)
                    pred = output.argmax(dim=1) if target == 'multi' else (torch.sigmoid(output) >= 0.5).long()

                    y_pred_val.extend(pred.cpu().numpy())
                    y_true_val.extend(y_batch.cpu().numpy())

            # CSI and scheduler update.
            val_csi = calculate_csi(y_true_val, y_pred_val)
            scheduler.step(val_csi)

            # Early stopping: keep a deep copy of the best-performing weights.
            if val_csi > best_fold_csi:
                best_fold_csi = val_csi
                counter = 0
                best_model = copy.deepcopy(model)
            else:
                counter += 1

            if counter >= patience:
                print(f" Early stopping at epoch {epoch+1}, Best CSI: {best_fold_csi:.4f}")
                break

        # Fallback: if validation CSI never improved, keep the last model.
        if best_model is None:
            best_model = model

        scalers.append(scaler)  # stored in fold order
        models.append(best_model)
        print(f" Fold {fold} ํ์ต ์๋ฃ (๊ฒ์ฆ CSI: {best_fold_csi:.4f})")

    # Model save directory.
    save_dir = f'../save_model/{model_choose}_optima'
    os.makedirs(save_dir, exist_ok=True)

    # Build the output file name.
    if data_sample == 'pure':
        model_filename = f'{model_choose}_pure_{region}.pkl'
    else:
        model_filename = f'{model_choose}_{data_sample}_{region}.pkl'

    model_path = f'{save_dir}/{model_filename}'

    # Save all fold models at once as a list.
    joblib.dump(models, model_path)
    print(f"\n๋ชจ๋ ๋ชจ๋ธ ์ ์ฅ ์๋ฃ: {model_path} (์ด {len(models)}๊ฐ fold)")

    # Save the scalers separately.
    scaler_save_dir = f'../save_model/{model_choose}_optima/scaler'
    os.makedirs(scaler_save_dir, exist_ok=True)

    # Scaler file name (same pattern as the model file).
    if data_sample == 'pure':
        scaler_filename = f'{model_choose}_pure_{region}_scaler.pkl'
    else:
        scaler_filename = f'{model_choose}_{data_sample}_{region}_scaler.pkl'

    scaler_path = f'{scaler_save_dir}/{scaler_filename}'
    joblib.dump(scalers, scaler_path)
    print(f"Scaler ์ ์ฅ ์๋ฃ: {scaler_path} (์ด {len(scalers)}๊ฐ fold)")

    return model_path
|
Analysis_code/5.optima/deepgbm_smotenc_ctgan20000/deepgbm_smotenc_ctgan20000_busan.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import sys
import traceback
import torch
from utils import *

# Fix the Python and NumPy RNG seeds for reproducibility. torch is imported
# above, so seed its RNG as well (the original comment promised fixed seeds
# but never seeded torch).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


# 1. Create the study with direction="maximize": a higher CSI score is better.
study = optuna.create_study(
    direction="maximize",
    # Observe the first 10 reported steps of each trial before median pruning.
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
)


def print_trial_callback(study, trial):
    """Print per-trial details plus the best result found so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # BUG FIX: Study.best_value never returns None -- it raises ValueError
    # while no trial has reached COMPLETE state (e.g. the very first trial was
    # pruned or failed). The old `is not None` guard therefore could not
    # prevent a crash that would abort study.optimize(); catch the ValueError.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: n/a (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="busan", data_sample='smotenc_ctgan20000'),
    n_trials=100,
    callbacks=[print_trial_callback],
)

# 3. Report and persist the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every trial that produced a value
    # (pruned/failed trials have trial.value == None and are skipped).
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if csi_scores:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the Study object for later inspection.
        # NOTE(review): base_dir resolves TWO levels up from this file (the
        # parent of 5.optima), although the original comment claimed the
        # 5.optima directory itself -- confirm which location is intended.
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_smotenc_ctgan20000_busan_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model using the best hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_path = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="busan",
            data_sample='smotenc_ctgan20000',
            target='multi',
            n_folds=3,
            random_state=seed,
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    # Best-effort reporting: keep the raw exception visible but exit cleanly.
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    traceback.print_exc()

# Exit cleanly so wrapper scripts treat the run as successful.
sys.exit(0)
Analysis_code/5.optima/deepgbm_smotenc_ctgan20000/deepgbm_smotenc_ctgan20000_daegu.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
import numpy as np
import random
import pandas as pd
import joblib
import os
import sys
import traceback
import torch
from utils import *

# Fix the Python and NumPy RNG seeds for reproducibility. torch is imported
# above, so seed its RNG as well (the original comment promised fixed seeds
# but never seeded torch).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


# 1. Create the study with direction="maximize": a higher CSI score is better.
study = optuna.create_study(
    direction="maximize",
    # Observe the first 10 reported steps of each trial before median pruning.
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
)


def print_trial_callback(study, trial):
    """Print per-trial details plus the best result found so far."""
    print(f"\n{'='*80}")
    print(f"Trial {trial.number} ์๋ฃ")
    if trial.value is not None:
        print(f"  Value (CSI): {trial.value:.6f}")
    else:
        print(f"  Value: {trial.value}")
    print(f"  Parameters: {trial.params}")
    # BUG FIX: Study.best_value never returns None -- it raises ValueError
    # while no trial has reached COMPLETE state (e.g. the very first trial was
    # pruned or failed). The old `is not None` guard therefore could not
    # prevent a crash that would abort study.optimize(); catch the ValueError.
    try:
        print(f"  Best Value (CSI): {study.best_value:.6f}")
        print(f"  Best Trial: {study.best_trial.number}")
        print(f"  Best Parameters: {study.best_params}")
    except ValueError:
        print("  Best Value: n/a (no completed trials yet)")
    print(f"{'='*80}\n")


# 2. Run the optimization.
study.optimize(
    lambda trial: objective(trial, model_choose="deepgbm", region="daegu", data_sample='smotenc_ctgan20000'),
    n_trials=100,
    callbacks=[print_trial_callback],
)

# 3. Report and persist the results.
print(f"\n์ต์ ํ ์๋ฃ.")
print(f"Best CSI Score: {study.best_value:.4f}")
print(f"Best Hyperparameters: {study.best_params}")

try:
    # Collect the CSI score of every trial that produced a value
    # (pruned/failed trials have trial.value == None and are skipped).
    csi_scores = [trial.value for trial in study.trials if trial.value is not None]

    if csi_scores:
        print(f"\n์ต์ ํ ๊ณผ์ ์์ฝ:")
        print(f"  - ์ด ์๋ ํ์: {len(study.trials)}")
        print(f"  - ์ฑ๊ณตํ ์๋: {len(csi_scores)}")
        print(f"  - ์ต์ด CSI: {csi_scores[0]:.4f}")
        print(f"  - ์ต์ข CSI: {csi_scores[-1]:.4f}")
        print(f"  - ์ต๊ณ CSI: {max(csi_scores):.4f}")
        print(f"  - ์ต์ CSI: {min(csi_scores):.4f}")
        print(f"  - ํ๊ท CSI: {np.mean(csi_scores):.4f}")

        # Persist the Study object for later inspection.
        # NOTE(review): base_dir resolves TWO levels up from this file (the
        # parent of 5.optima), although the original comment claimed the
        # 5.optima directory itself -- confirm which location is intended.
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(os.path.dirname(current_file_dir))
        os.makedirs(os.path.join(base_dir, "optimization_history"), exist_ok=True)
        study_path = os.path.join(base_dir, "optimization_history/deepgbm_smotenc_ctgan20000_daegu_trials.pkl")
        joblib.dump(study, study_path)
        print(f"\n์ต์ ํ Study ๊ฐ์ฒด๊ฐ {study_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        # Retrain and save the final model using the best hyperparameters.
        print("\n" + "="*50)
        print("์ต์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ์ต์ข ๋ชจ๋ธ ํ์ต ์์")
        print("="*50)

        best_params = study.best_params
        model_path = train_final_model(
            best_params=best_params,
            model_choose="deepgbm",
            region="daegu",
            data_sample='smotenc_ctgan20000',
            target='multi',
            n_folds=3,
            random_state=seed,
        )

        print(f"\n์ต์ข ๋ชจ๋ธ ํ์ต ๋ฐ ์ ์ฅ ์๋ฃ!")
        print(f"์ ์ฅ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")

except Exception as e:
    # Best-effort reporting: keep the raw exception visible but exit cleanly.
    print(f"\nโ ๏ธ ์ต์ ํ ๊ฒฐ๊ณผ ๋ถ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
    traceback.print_exc()

# Exit cleanly so wrapper scripts treat the run as successful.
sys.exit(0)