Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +65 -0
- models/.ipynb_checkpoints/Untitled-checkpoint.ipynb +766 -0
- models/.ipynb_checkpoints/benchmark_model-8bit-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model-Copy1-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model_treshold-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model_vanilla-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/eval_basic-checkpoint.ipynb +305 -0
- models/.ipynb_checkpoints/eval_basic-extend-checkpoint.ipynb +485 -0
- models/.ipynb_checkpoints/eval_mask-8-checkpoint.ipynb +372 -0
- models/.ipynb_checkpoints/eval_mask-8-extend-checkpoint.ipynb +483 -0
- models/.ipynb_checkpoints/eval_mask-checkpoint.ipynb +323 -0
- models/.ipynb_checkpoints/eval_mask-extend-checkpoint.ipynb +500 -0
- models/.ipynb_checkpoints/eval_mask_threshold-extend-checkpoint.ipynb +460 -0
- models/.ipynb_checkpoints/plot_reatime_hits-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/practice_cnn_train-checkpoint.ipynb +326 -0
- models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb +3 -0
- models/.ipynb_checkpoints/recover_new_crab-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/recover_new_crab-debug-checkpoint.ipynb +273 -0
- models/.ipynb_checkpoints/recover_new_frb-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/resnet_model-checkpoint.py +160 -0
- models/.ipynb_checkpoints/resnet_model_mask-checkpoint.py +166 -0
- models/.ipynb_checkpoints/train-checkpoint.py +105 -0
- models/.ipynb_checkpoints/train-mask-8-checkpoint.py +103 -0
- models/.ipynb_checkpoints/train-mask-checkpoint.py +104 -0
- models/.ipynb_checkpoints/utils-checkpoint.py +393 -0
- models/.ipynb_checkpoints/utils_batched_preproc-checkpoint.py +65 -0
- models/HITS-FEB-10.zip +3 -0
- models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png +3 -0
- models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739230556_9.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739230556_9.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739231399_1.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739231399_1.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739231802_11.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739231802_11.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_13.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_13.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_14.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_14.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739235333_29.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739235333_29.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739235841_12.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739235841_12.png +3 -0
- models/HITS-FEB-10/hit_50233055_1739232802_29.npy +3 -0
- models/HITS-FEB-10/hit_50233055_1739232802_29.png +3 -0
- models/HITS-FEB-10/hit_52111435_1739229641_28.npy +3 -0
- models/HITS-FEB-10/hit_52111435_1739229641_28.png +3 -0
- models/HITS-FEB-10/hit_52550001_1739233595_4.npy +3 -0
- models/HITS-FEB-10/hit_52550001_1739233595_4.png +3 -0
.gitattributes
CHANGED
|
@@ -36,3 +36,68 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
models/HITS-FEB-10/hit_100000000_1739230556_9.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
models/HITS-FEB-10/hit_100000000_1739231399_1.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
models/HITS-FEB-10/hit_100000000_1739231802_11.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
models/HITS-FEB-10/hit_100000000_1739234628_13.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
models/HITS-FEB-10/hit_100000000_1739234628_14.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
models/HITS-FEB-10/hit_100000000_1739235333_29.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
models/HITS-FEB-10/hit_100000000_1739235841_12.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
models/HITS-FEB-10/hit_50233055_1739232802_29.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
models/HITS-FEB-10/hit_52111435_1739229641_28.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
models/HITS-FEB-10/hit_52550001_1739233595_4.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
models/HITS-FEB-10/hit_57096732_1739234611_11.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
models/HITS-FEB-10/hit_57521253_1739232651_11.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
models/HITS-FEB-10/hit_58032264_1739232672_22.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
models/HITS-FEB-10/hit_58165746_1739230560_10.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
models/HITS-FEB-10/hit_62177701_1739230031_23.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
models/HITS-FEB-10/hit_64237575_1739233604_2.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
models/HITS-FEB-10/hit_67249737_1739231769_23.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
models/HITS-FEB-10/hit_71882680_1739230648_29.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
models/HITS-FEB-10/hit_72677566_1739232113_26.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
models/HITS-FEB-10/hit_74160848_1739234611_5.png filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
models/HITS-FEB-10/hit_75109552_1739231790_10.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
models/HITS-FEB-10/hit_79640130_1739231950_5.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
models/HITS-FEB-10/hit_81910572_1739231764_29.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
models/HITS-FEB-10/hit_83296520_1739233906_31.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
models/HITS-FEB-10/hit_84171229_1739231886_5.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
models/HITS-FEB-10/hit_84411238_1739233784_3.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
models/HITS-FEB-10/hit_87957837_1739232059_29.png filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
models/HITS-FEB-10/hit_88699241_1739235027_30.png filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
models/HITS-FEB-10/hit_90233018_1739235740_0.png filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
models/HITS-FEB-10/hit_93808281_1739233104_14.png filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
models/HITS-FEB-10/hit_93821122_1739232374_4.png filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
models/HITS-FEB-10/hit_94705507_1739231215_6.png filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
models/HITS-FEB-10/hit_95400645_1739230724_21.png filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
models/HITS-FEB-10/hit_95544329_1739233784_17.png filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
models/HITS-FEB-10/hit_96119644_1739232285_1.png filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
models/HITS-FEB-10/hit_96369222_1739233852_3.png filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
models/HITS-FEB-10/hit_96470689_1739233843_18.png filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
models/HITS-FEB-10/hit_96471079_1739232857_0.png filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
models/HITS-FEB-10/hit_96497133_1739233692_27.png filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
models/HITS-FEB-10/hit_98139322_1739231500_5.png filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
models/HITS-FEB-10/hit_98582385_1739232894_12.png filter=lfs diff=lfs merge=lfs -text
|
| 83 |
+
models/HITS-FEB-10/hit_98697207_1739229930_16.png filter=lfs diff=lfs merge=lfs -text
|
| 84 |
+
models/HITS-FEB-10/hit_99172221_1739232365_29.png filter=lfs diff=lfs merge=lfs -text
|
| 85 |
+
models/HITS-FEB-10/hit_99314646_1739233667_1.png filter=lfs diff=lfs merge=lfs -text
|
| 86 |
+
models/HITS-FEB-10/hit_99756914_1739230207_8.png filter=lfs diff=lfs merge=lfs -text
|
| 87 |
+
models/HITS-FEB-10/hit_99939211_1739233705_24.png filter=lfs diff=lfs merge=lfs -text
|
| 88 |
+
models/HITS-FEB-10/hit_99972041_1739234066_9.png filter=lfs diff=lfs merge=lfs -text
|
| 89 |
+
models/HITS-FEB-10/hit_99977277_1739231773_10.png filter=lfs diff=lfs merge=lfs -text
|
| 90 |
+
models/HITS-FEB-10/hit_99986237_1739234058_8.png filter=lfs diff=lfs merge=lfs -text
|
| 91 |
+
models/HITS-FEB-10/hit_99998032_1739232348_29.png filter=lfs diff=lfs merge=lfs -text
|
| 92 |
+
models/HITS-FEB-10/hit_99999287_1739233700_3.png filter=lfs diff=lfs merge=lfs -text
|
| 93 |
+
models/HITS-FEB-10/hit_99999351_1739235476_0.png filter=lfs diff=lfs merge=lfs -text
|
| 94 |
+
models/HITS-FEB-10/hit_99999979_1739232399_15.png filter=lfs diff=lfs merge=lfs -text
|
| 95 |
+
models/combined_frb_detections.pdf filter=lfs diff=lfs merge=lfs -text
|
| 96 |
+
models/combined_frb_detections.png filter=lfs diff=lfs merge=lfs -text
|
| 97 |
+
models/hits.png filter=lfs diff=lfs merge=lfs -text
|
| 98 |
+
models/hits_crab.pdf filter=lfs diff=lfs merge=lfs -text
|
| 99 |
+
models/models_mask/accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
|
| 100 |
+
models/models_mask/accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
|
| 101 |
+
models/models_mask/accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
|
| 102 |
+
models/recover_crab.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 103 |
+
models/recover_new_crab-debug.ipynb filter=lfs diff=lfs merge=lfs -text
|
models/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 10,
|
| 6 |
+
"id": "5577ffee-a5c9-4648-8849-95c2c7ebcebe",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 11 |
+
"from utils_batched_preproc import transform_batched, preproc_flip\n",
|
| 12 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 13 |
+
"import torch\n",
|
| 14 |
+
"import numpy as np\n",
|
| 15 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 16 |
+
"import torch\n",
|
| 17 |
+
"import torch.nn as nn\n",
|
| 18 |
+
"import torch.optim as optim\n",
|
| 19 |
+
"import tqdm \n",
|
| 20 |
+
"import torch.nn.functional as F\n",
|
| 21 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 22 |
+
"import pickle\n",
|
| 23 |
+
"import torch\n",
|
| 24 |
+
"from functorch import vmap"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": 3,
|
| 30 |
+
"id": "f1180d60-83e7-47ca-aa09-58d26af3c706",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"# def renorm_batched(data):\n",
|
| 35 |
+
"# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n",
|
| 36 |
+
"# std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n",
|
| 37 |
+
"# standardized_data = (data - mean) / std\n",
|
| 38 |
+
"# return standardized_data\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"# def transform_batched(data):\n",
|
| 41 |
+
"# copy_data = data.detach().clone()\n",
|
| 42 |
+
"# rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std\n",
|
| 43 |
+
"# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean\n",
|
| 44 |
+
"# masks_rms = [-1, 5]\n",
|
| 45 |
+
" \n",
|
| 46 |
+
"# # Prepare the new_data tensor\n",
|
| 47 |
+
"# num_masks = len(masks_rms) + 1\n",
|
| 48 |
+
"# new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"# # First layer: Apply renorm(log10(copy_data + epsilon))\n",
|
| 51 |
+
"# new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))\n",
|
| 52 |
+
"# for i, scale in enumerate(masks_rms, start=1):\n",
|
| 53 |
+
"# copy_data = data.detach().clone()\n",
|
| 54 |
+
" \n",
|
| 55 |
+
"# # Apply masking based on the scale\n",
|
| 56 |
+
"# if scale < 0:\n",
|
| 57 |
+
"# ind = copy_data < abs(scale) * rms + mean\n",
|
| 58 |
+
"# else:\n",
|
| 59 |
+
"# ind = copy_data > scale * rms + mean\n",
|
| 60 |
+
"# copy_data[ind] = 0\n",
|
| 61 |
+
" \n",
|
| 62 |
+
"# # Renormalize and log10 transform\n",
|
| 63 |
+
"# new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))\n",
|
| 64 |
+
" \n",
|
| 65 |
+
"# # Convert to float32\n",
|
| 66 |
+
"# new_data = new_data.type(torch.float32)\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"# # Chunk along the last dimension and stack\n",
|
| 69 |
+
"# slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing\n",
|
| 70 |
+
"# new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1\n",
|
| 71 |
+
"# new_data = torch.swapaxes(new_data, 0,1)\n",
|
| 72 |
+
"# # Reshape into final format\n",
|
| 73 |
+
"# new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions\n",
|
| 74 |
+
"# return new_data\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"\n"
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "code",
|
| 81 |
+
"execution_count": 11,
|
| 82 |
+
"id": "81cc81a8-ecef-43ef-a5cd-c35765384812",
|
| 83 |
+
"metadata": {
|
| 84 |
+
"scrolled": true
|
| 85 |
+
},
|
| 86 |
+
"outputs": [
|
| 87 |
+
{
|
| 88 |
+
"name": "stdout",
|
| 89 |
+
"output_type": "stream",
|
| 90 |
+
"text": [
|
| 91 |
+
"num params encoder 50840\n"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"name": "stderr",
|
| 96 |
+
"output_type": "stream",
|
| 97 |
+
"text": [
|
| 98 |
+
"/tmp/ipykernel_19147/1680389579.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 99 |
+
" model.load_state_dict(torch.load(model_path))\n"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"data": {
|
| 104 |
+
"text/plain": [
|
| 105 |
+
"DataParallel(\n",
|
| 106 |
+
" (module): ResNet(\n",
|
| 107 |
+
" (relu): ReLU()\n",
|
| 108 |
+
" (conv1): Sequential(\n",
|
| 109 |
+
" (0): Conv2d(24, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
|
| 110 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 111 |
+
" (2): ReLU()\n",
|
| 112 |
+
" )\n",
|
| 113 |
+
" (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
| 114 |
+
" (layer0): Sequential(\n",
|
| 115 |
+
" (0): ResidualBlock(\n",
|
| 116 |
+
" (conv1): Sequential(\n",
|
| 117 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 118 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 119 |
+
" (2): ReLU()\n",
|
| 120 |
+
" )\n",
|
| 121 |
+
" (conv2): Sequential(\n",
|
| 122 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 123 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 124 |
+
" )\n",
|
| 125 |
+
" (relu): ReLU()\n",
|
| 126 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 127 |
+
" (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 128 |
+
" )\n",
|
| 129 |
+
" (1): ResidualBlock(\n",
|
| 130 |
+
" (conv1): Sequential(\n",
|
| 131 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 132 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 133 |
+
" (2): ReLU()\n",
|
| 134 |
+
" )\n",
|
| 135 |
+
" (conv2): Sequential(\n",
|
| 136 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 137 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 138 |
+
" )\n",
|
| 139 |
+
" (relu): ReLU()\n",
|
| 140 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 141 |
+
" (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 142 |
+
" )\n",
|
| 143 |
+
" (2): ResidualBlock(\n",
|
| 144 |
+
" (conv1): Sequential(\n",
|
| 145 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 146 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 147 |
+
" (2): ReLU()\n",
|
| 148 |
+
" )\n",
|
| 149 |
+
" (conv2): Sequential(\n",
|
| 150 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 151 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 152 |
+
" )\n",
|
| 153 |
+
" (relu): ReLU()\n",
|
| 154 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 155 |
+
" (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 156 |
+
" )\n",
|
| 157 |
+
" )\n",
|
| 158 |
+
" (layer1): Sequential(\n",
|
| 159 |
+
" (0): ResidualBlock(\n",
|
| 160 |
+
" (conv1): Sequential(\n",
|
| 161 |
+
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
|
| 162 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 163 |
+
" (2): ReLU()\n",
|
| 164 |
+
" )\n",
|
| 165 |
+
" (conv2): Sequential(\n",
|
| 166 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 167 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 168 |
+
" )\n",
|
| 169 |
+
" (downsample): Sequential(\n",
|
| 170 |
+
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))\n",
|
| 171 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 172 |
+
" )\n",
|
| 173 |
+
" (relu): ReLU()\n",
|
| 174 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 175 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 176 |
+
" )\n",
|
| 177 |
+
" (1): ResidualBlock(\n",
|
| 178 |
+
" (conv1): Sequential(\n",
|
| 179 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 180 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 181 |
+
" (2): ReLU()\n",
|
| 182 |
+
" )\n",
|
| 183 |
+
" (conv2): Sequential(\n",
|
| 184 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 185 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 186 |
+
" )\n",
|
| 187 |
+
" (relu): ReLU()\n",
|
| 188 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 189 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 190 |
+
" )\n",
|
| 191 |
+
" (2): ResidualBlock(\n",
|
| 192 |
+
" (conv1): Sequential(\n",
|
| 193 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 194 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 195 |
+
" (2): ReLU()\n",
|
| 196 |
+
" )\n",
|
| 197 |
+
" (conv2): Sequential(\n",
|
| 198 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 199 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 200 |
+
" )\n",
|
| 201 |
+
" (relu): ReLU()\n",
|
| 202 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 203 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 204 |
+
" )\n",
|
| 205 |
+
" (3): ResidualBlock(\n",
|
| 206 |
+
" (conv1): Sequential(\n",
|
| 207 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 208 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 209 |
+
" (2): ReLU()\n",
|
| 210 |
+
" )\n",
|
| 211 |
+
" (conv2): Sequential(\n",
|
| 212 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 213 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 214 |
+
" )\n",
|
| 215 |
+
" (relu): ReLU()\n",
|
| 216 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 217 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 218 |
+
" )\n",
|
| 219 |
+
" )\n",
|
| 220 |
+
" (layer2): Sequential(\n",
|
| 221 |
+
" (0): ResidualBlock(\n",
|
| 222 |
+
" (conv1): Sequential(\n",
|
| 223 |
+
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
|
| 224 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 225 |
+
" (2): ReLU()\n",
|
| 226 |
+
" )\n",
|
| 227 |
+
" (conv2): Sequential(\n",
|
| 228 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 229 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 230 |
+
" )\n",
|
| 231 |
+
" (downsample): Sequential(\n",
|
| 232 |
+
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))\n",
|
| 233 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 234 |
+
" )\n",
|
| 235 |
+
" (relu): ReLU()\n",
|
| 236 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 237 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 238 |
+
" )\n",
|
| 239 |
+
" (1): ResidualBlock(\n",
|
| 240 |
+
" (conv1): Sequential(\n",
|
| 241 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 242 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 243 |
+
" (2): ReLU()\n",
|
| 244 |
+
" )\n",
|
| 245 |
+
" (conv2): Sequential(\n",
|
| 246 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 247 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 248 |
+
" )\n",
|
| 249 |
+
" (relu): ReLU()\n",
|
| 250 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 251 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 252 |
+
" )\n",
|
| 253 |
+
" (2): ResidualBlock(\n",
|
| 254 |
+
" (conv1): Sequential(\n",
|
| 255 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 256 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 257 |
+
" (2): ReLU()\n",
|
| 258 |
+
" )\n",
|
| 259 |
+
" (conv2): Sequential(\n",
|
| 260 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 261 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 262 |
+
" )\n",
|
| 263 |
+
" (relu): ReLU()\n",
|
| 264 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 265 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 266 |
+
" )\n",
|
| 267 |
+
" (3): ResidualBlock(\n",
|
| 268 |
+
" (conv1): Sequential(\n",
|
| 269 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 270 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 271 |
+
" (2): ReLU()\n",
|
| 272 |
+
" )\n",
|
| 273 |
+
" (conv2): Sequential(\n",
|
| 274 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 275 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 276 |
+
" )\n",
|
| 277 |
+
" (relu): ReLU()\n",
|
| 278 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 279 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 280 |
+
" )\n",
|
| 281 |
+
" (4): ResidualBlock(\n",
|
| 282 |
+
" (conv1): Sequential(\n",
|
| 283 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 284 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 285 |
+
" (2): ReLU()\n",
|
| 286 |
+
" )\n",
|
| 287 |
+
" (conv2): Sequential(\n",
|
| 288 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 289 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 290 |
+
" )\n",
|
| 291 |
+
" (relu): ReLU()\n",
|
| 292 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 293 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 294 |
+
" )\n",
|
| 295 |
+
" (5): ResidualBlock(\n",
|
| 296 |
+
" (conv1): Sequential(\n",
|
| 297 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 298 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 299 |
+
" (2): ReLU()\n",
|
| 300 |
+
" )\n",
|
| 301 |
+
" (conv2): Sequential(\n",
|
| 302 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 303 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 304 |
+
" )\n",
|
| 305 |
+
" (relu): ReLU()\n",
|
| 306 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 307 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 308 |
+
" )\n",
|
| 309 |
+
" )\n",
|
| 310 |
+
" (layer3): Sequential(\n",
|
| 311 |
+
" (0): ResidualBlock(\n",
|
| 312 |
+
" (conv1): Sequential(\n",
|
| 313 |
+
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 314 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 315 |
+
" (2): ReLU()\n",
|
| 316 |
+
" )\n",
|
| 317 |
+
" (conv2): Sequential(\n",
|
| 318 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 319 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 320 |
+
" )\n",
|
| 321 |
+
" (downsample): Sequential(\n",
|
| 322 |
+
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
| 323 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 324 |
+
" )\n",
|
| 325 |
+
" (relu): ReLU()\n",
|
| 326 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 327 |
+
" (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 328 |
+
" )\n",
|
| 329 |
+
" (1): ResidualBlock(\n",
|
| 330 |
+
" (conv1): Sequential(\n",
|
| 331 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 332 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 333 |
+
" (2): ReLU()\n",
|
| 334 |
+
" )\n",
|
| 335 |
+
" (conv2): Sequential(\n",
|
| 336 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 337 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 338 |
+
" )\n",
|
| 339 |
+
" (relu): ReLU()\n",
|
| 340 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 341 |
+
" (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 342 |
+
" )\n",
|
| 343 |
+
" (2): ResidualBlock(\n",
|
| 344 |
+
" (conv1): Sequential(\n",
|
| 345 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 346 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 347 |
+
" (2): ReLU()\n",
|
| 348 |
+
" )\n",
|
| 349 |
+
" (conv2): Sequential(\n",
|
| 350 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 351 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 352 |
+
" )\n",
|
| 353 |
+
" (relu): ReLU()\n",
|
| 354 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
| 355 |
+
" (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
| 356 |
+
" )\n",
|
| 357 |
+
" )\n",
|
| 358 |
+
" (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n",
|
| 359 |
+
" (fc): Linear(in_features=39424, out_features=2, bias=True)\n",
|
| 360 |
+
" (dropout1): Dropout(p=0.3, inplace=False)\n",
|
| 361 |
+
" (encoder): Sequential(\n",
|
| 362 |
+
" (0): Conv2d(24, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 363 |
+
" (1): ReLU(inplace=True)\n",
|
| 364 |
+
" (2): Dropout(p=0.3, inplace=False)\n",
|
| 365 |
+
" (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 366 |
+
" (4): ReLU(inplace=True)\n",
|
| 367 |
+
" (5): Dropout(p=0.3, inplace=False)\n",
|
| 368 |
+
" (6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 369 |
+
" (7): ReLU(inplace=True)\n",
|
| 370 |
+
" (8): Dropout(p=0.3, inplace=False)\n",
|
| 371 |
+
" (9): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
| 372 |
+
" (10): Sigmoid()\n",
|
| 373 |
+
" )\n",
|
| 374 |
+
" )\n",
|
| 375 |
+
")"
|
| 376 |
+
]
|
| 377 |
+
},
|
| 378 |
+
"execution_count": 11,
|
| 379 |
+
"metadata": {},
|
| 380 |
+
"output_type": "execute_result"
|
| 381 |
+
}
|
| 382 |
+
],
|
| 383 |
+
"source": [
|
| 384 |
+
"model_path = 'models/model-47-99.125.pt'\n",
|
| 385 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 386 |
+
"\n",
|
| 387 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=2).to(device)\n",
|
| 388 |
+
"model = nn.DataParallel(model)\n",
|
| 389 |
+
"model = model.to(device)\n",
|
| 390 |
+
"model.load_state_dict(torch.load(model_path))\n",
|
| 391 |
+
"model.eval()"
|
| 392 |
+
]
|
| 393 |
+
},
|
| 394 |
+
{
|
| 395 |
+
"cell_type": "code",
|
| 396 |
+
"execution_count": 12,
|
| 397 |
+
"id": "58b8c338-df2f-4ef0-92cf-409c9f034cab",
|
| 398 |
+
"metadata": {},
|
| 399 |
+
"outputs": [
|
| 400 |
+
{
|
| 401 |
+
"name": "stderr",
|
| 402 |
+
"output_type": "stream",
|
| 403 |
+
"text": [
|
| 404 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 405 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
+
{
|
| 409 |
+
"name": "stdout",
|
| 410 |
+
"output_type": "stream",
|
| 411 |
+
"text": [
|
| 412 |
+
"tensor([[ 4.1780, -4.1750],\n",
|
| 413 |
+
" [ 4.6414, -4.6303],\n",
|
| 414 |
+
" [ 5.0103, -5.0162],\n",
|
| 415 |
+
" [ 4.8273, -4.8311],\n",
|
| 416 |
+
" [ 4.8523, -4.8661],\n",
|
| 417 |
+
" [ 4.8855, -4.9074],\n",
|
| 418 |
+
" [ 4.4973, -4.5213],\n",
|
| 419 |
+
" [ 5.5996, -5.6192],\n",
|
| 420 |
+
" [ 4.7929, -4.8116],\n",
|
| 421 |
+
" [ 5.5999, -5.5925],\n",
|
| 422 |
+
" [ 4.7918, -4.7998],\n",
|
| 423 |
+
" [ 4.0914, -4.0766],\n",
|
| 424 |
+
" [ 0.7072, -0.6955],\n",
|
| 425 |
+
" [ 4.7136, -4.7234],\n",
|
| 426 |
+
" [ 5.3918, -5.4307],\n",
|
| 427 |
+
" [ 4.5491, -4.5524],\n",
|
| 428 |
+
" [ 4.5412, -4.5391],\n",
|
| 429 |
+
" [ 4.6264, -4.6137],\n",
|
| 430 |
+
" [ 3.9378, -3.9300],\n",
|
| 431 |
+
" [ 5.0673, -5.0792],\n",
|
| 432 |
+
" [ 5.7389, -5.7330],\n",
|
| 433 |
+
" [ 5.2259, -5.2326],\n",
|
| 434 |
+
" [ 5.3856, -5.4036],\n",
|
| 435 |
+
" [ 5.0781, -5.1232],\n",
|
| 436 |
+
" [ 5.2432, -5.2584],\n",
|
| 437 |
+
" [ 5.8163, -5.8209],\n",
|
| 438 |
+
" [ 4.7730, -4.7823],\n",
|
| 439 |
+
" [ 5.1320, -5.1657],\n",
|
| 440 |
+
" [ 5.6486, -5.6485],\n",
|
| 441 |
+
" [ 3.7626, -3.7674],\n",
|
| 442 |
+
" [ 4.1834, -4.1797],\n",
|
| 443 |
+
" [ 4.4452, -4.4566]], device='cuda:0', grad_fn=<GatherBackward>)\n"
|
| 444 |
+
]
|
| 445 |
+
}
|
| 446 |
+
],
|
| 447 |
+
"source": [
|
| 448 |
+
"test_in = abs(torch.randn(32, 192, 2048).to(device))\n",
|
| 449 |
+
"results = []\n",
|
| 450 |
+
"for i in range(32):\n",
|
| 451 |
+
" results.append(transform(test_in[i,:,:]))\n",
|
| 452 |
+
"intermediate = torch.stack(results).cuda()\n",
|
| 453 |
+
"out = model(intermediate)\n",
|
| 454 |
+
"test_in.cpu().detach().numpy().tofile(\"input.bin\")\n",
|
| 455 |
+
"intermediate.cpu().detach().numpy().tofile(\"intermediate.bin\")\n",
|
| 456 |
+
"out.cpu().detach().numpy().tofile(\"output.bin\")\n",
|
| 457 |
+
"print(out)"
|
| 458 |
+
]
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"cell_type": "code",
|
| 462 |
+
"execution_count": 13,
|
| 463 |
+
"id": "ad56299a-44e4-4d6b-afcc-18a5f4cf0138",
|
| 464 |
+
"metadata": {},
|
| 465 |
+
"outputs": [
|
| 466 |
+
{
|
| 467 |
+
"ename": "NameError",
|
| 468 |
+
"evalue": "name 'preproc_flip' is not defined",
|
| 469 |
+
"output_type": "error",
|
| 470 |
+
"traceback": [
|
| 471 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 472 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
| 473 |
+
"Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m preproc_model \u001b[38;5;241m=\u001b[39m preproc_flip()\n\u001b[1;32m 2\u001b[0m Convert_ONNX(preproc_model,\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodels_mask/preproc_flip.onnx\u001b[39m\u001b[38;5;124m'\u001b[39m, input_data_mock\u001b[38;5;241m=\u001b[39mtest_in\u001b[38;5;241m.\u001b[39mto(device))\n",
|
| 474 |
+
"\u001b[0;31mNameError\u001b[0m: name 'preproc_flip' is not defined"
|
| 475 |
+
]
|
| 476 |
+
}
|
| 477 |
+
],
|
| 478 |
+
"source": [
|
| 479 |
+
"preproc_model = preproc_flip()\n",
|
| 480 |
+
"Convert_ONNX(preproc_model,f'models_mask/preproc_flip.onnx', input_data_mock=test_in.to(device))\n",
|
| 481 |
+
"# Convert_ONNX(model.module,f'models_mask/model_test.onnx', input_data_mock=intermediate.to(device))"
|
| 482 |
+
]
|
| 483 |
+
},
|
| 484 |
+
{
|
| 485 |
+
"cell_type": "code",
|
| 486 |
+
"execution_count": 7,
|
| 487 |
+
"id": "30e84a9b-0d4f-4cb2-a92b-2e3f0b2ccb20",
|
| 488 |
+
"metadata": {},
|
| 489 |
+
"outputs": [
|
| 490 |
+
{
|
| 491 |
+
"data": {
|
| 492 |
+
"text/plain": [
|
| 493 |
+
"torch.Size([32, 192, 2048])"
|
| 494 |
+
]
|
| 495 |
+
},
|
| 496 |
+
"execution_count": 7,
|
| 497 |
+
"metadata": {},
|
| 498 |
+
"output_type": "execute_result"
|
| 499 |
+
}
|
| 500 |
+
],
|
| 501 |
+
"source": [
|
| 502 |
+
"test_in.shape"
|
| 503 |
+
]
|
| 504 |
+
},
|
| 505 |
+
{
|
| 506 |
+
"cell_type": "code",
|
| 507 |
+
"execution_count": 13,
|
| 508 |
+
"id": "1bb26727-7914-470e-bb48-43d7ee81cb50",
|
| 509 |
+
"metadata": {},
|
| 510 |
+
"outputs": [
|
| 511 |
+
{
|
| 512 |
+
"data": {
|
| 513 |
+
"text/plain": [
|
| 514 |
+
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
|
| 515 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
| 516 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
| 517 |
+
" ...,\n",
|
| 518 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
| 519 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
| 520 |
+
" [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')"
|
| 521 |
+
]
|
| 522 |
+
},
|
| 523 |
+
"execution_count": 13,
|
| 524 |
+
"metadata": {},
|
| 525 |
+
"output_type": "execute_result"
|
| 526 |
+
}
|
| 527 |
+
],
|
| 528 |
+
"source": [
|
| 529 |
+
"import torch\n",
|
| 530 |
+
"torch.flip(test_in[0,:,:], dims = (0,)) - torch.flipud(test_in[0,:,:])"
|
| 531 |
+
]
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"cell_type": "code",
|
| 535 |
+
"execution_count": 29,
|
| 536 |
+
"id": "aeaaab90-6a2a-4851-a1ca-28c54a446573",
|
| 537 |
+
"metadata": {},
|
| 538 |
+
"outputs": [
|
| 539 |
+
{
|
| 540 |
+
"name": "stdout",
|
| 541 |
+
"output_type": "stream",
|
| 542 |
+
"text": [
|
| 543 |
+
"tensor(float)\n",
|
| 544 |
+
"torch.float32\n",
|
| 545 |
+
"Input Name: modelInput\n",
|
| 546 |
+
"Output Name: modelOutput\n",
|
| 547 |
+
"[array([[ 4.3262615, -4.3409047],\n",
|
| 548 |
+
" [ 4.9648395, -4.968621 ],\n",
|
| 549 |
+
" [ 5.5126643, -5.522872 ],\n",
|
| 550 |
+
" [ 4.7735534, -4.8004475],\n",
|
| 551 |
+
" [ 4.0924144, -4.112945 ],\n",
|
| 552 |
+
" [ 4.588802 , -4.6043544],\n",
|
| 553 |
+
" [ 4.6231914, -4.617625 ],\n",
|
| 554 |
+
" [ 5.229881 , -5.2555394],\n",
|
| 555 |
+
" [ 4.877381 , -4.882144 ],\n",
|
| 556 |
+
" [ 5.2514744, -5.2786503],\n",
|
| 557 |
+
" [ 4.2948875, -4.3169603],\n",
|
| 558 |
+
" [ 4.5997186, -4.6177607],\n",
|
| 559 |
+
" [ 4.9509926, -4.9685597],\n",
|
| 560 |
+
" [ 4.933158 , -4.9568825],\n",
|
| 561 |
+
" [ 4.747336 , -4.7639017],\n",
|
| 562 |
+
" [ 5.020595 , -5.0202913],\n",
|
| 563 |
+
" [ 4.914437 , -4.9206715],\n",
|
| 564 |
+
" [ 5.193108 , -5.1925435],\n",
|
| 565 |
+
" [ 4.5233765, -4.512763 ],\n",
|
| 566 |
+
" [ 4.7573333, -4.762632 ],\n",
|
| 567 |
+
" [ 5.268702 , -5.2838397],\n",
|
| 568 |
+
" [ 4.857734 , -4.8605857],\n",
|
| 569 |
+
" [ 5.1886744, -5.2047734],\n",
|
| 570 |
+
" [ 5.512568 , -5.5503583],\n",
|
| 571 |
+
" [ 5.320961 , -5.344709 ],\n",
|
| 572 |
+
" [ 4.1023226, -4.1073256],\n",
|
| 573 |
+
" [ 5.17857 , -5.185736 ],\n",
|
| 574 |
+
" [ 4.997028 , -4.9933476],\n",
|
| 575 |
+
" [ 4.771303 , -4.767269 ],\n",
|
| 576 |
+
" [ 5.312805 , -5.3265243],\n",
|
| 577 |
+
" [ 5.0030336, -5.0492 ],\n",
|
| 578 |
+
" [ 5.429731 , -5.4249325]], dtype=float32)]\n"
|
| 579 |
+
]
|
| 580 |
+
}
|
| 581 |
+
],
|
| 582 |
+
"source": [
|
| 583 |
+
"import onnxruntime as ort\n",
|
| 584 |
+
"import onnx\n",
|
| 585 |
+
"\n",
|
| 586 |
+
"# Path to your ONNX model\n",
|
| 587 |
+
"model_path = \"models/model-47-99.125.onnx\"\n",
|
| 588 |
+
"\n",
|
| 589 |
+
"# Load the ONNX model\n",
|
| 590 |
+
"session = ort.InferenceSession(model_path)\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"# Get input and output details\n",
|
| 593 |
+
"input_name = session.get_inputs()[0].name\n",
|
| 594 |
+
"output_name = session.get_outputs()[0].name\n",
|
| 595 |
+
"\n",
|
| 596 |
+
"print(session.get_inputs()[0].type)\n",
|
| 597 |
+
"print(test_in.dtype)\n",
|
| 598 |
+
"\n",
|
| 599 |
+
"print(f\"Input Name: {input_name}\")\n",
|
| 600 |
+
"print(f\"Output Name: {output_name}\")\n",
|
| 601 |
+
"\n",
|
| 602 |
+
"# Example Input Data (Replace with your actual input data)\n",
|
| 603 |
+
"import numpy as np\n",
|
| 604 |
+
"\n",
|
| 605 |
+
"# Perform inference\n",
|
| 606 |
+
"outputs = session.run([output_name], {input_name: intermediate.cpu().numpy()})\n",
|
| 607 |
+
"print(outputs)\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"onnx_model = onnx.load(model_path)"
|
| 610 |
+
]
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"cell_type": "code",
|
| 614 |
+
"execution_count": 30,
|
| 615 |
+
"id": "f250739d-4c8a-4752-964a-d0b929c396f4",
|
| 616 |
+
"metadata": {},
|
| 617 |
+
"outputs": [],
|
| 618 |
+
"source": [
|
| 619 |
+
"# import onnxruntime as ort\n",
|
| 620 |
+
"# import onnx\n",
|
| 621 |
+
"\n",
|
| 622 |
+
"# # Path to your ONNX model\n",
|
| 623 |
+
"# model_path = \"models_mask/preproc_test.onnx\"\n",
|
| 624 |
+
"\n",
|
| 625 |
+
"# # Load the ONNX model\n",
|
| 626 |
+
"# session = ort.InferenceSession(model_path)\n",
|
| 627 |
+
"\n",
|
| 628 |
+
"# # Get input and output details\n",
|
| 629 |
+
"# input_name = session.get_inputs()[0].name\n",
|
| 630 |
+
"# output_name = session.get_outputs()[0].name\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"# print(session.get_inputs()[0].type)\n",
|
| 633 |
+
"# print(test_in.dtype)\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# print(f\"Input Name: {input_name}\")\n",
|
| 636 |
+
"# print(f\"Output Name: {output_name}\")\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"# # Example Input Data (Replace with your actual input data)\n",
|
| 639 |
+
"# import numpy as np\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"# # Perform inference\n",
|
| 642 |
+
"# outputs = session.run([output_name], {input_name: test_in.cpu().numpy()})\n",
|
| 643 |
+
"# print(\"Model Output:\", outputs)\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"# onnx_model = onnx.load(model_path)"
|
| 646 |
+
]
|
| 647 |
+
},
|
| 648 |
+
{
|
| 649 |
+
"cell_type": "code",
|
| 650 |
+
"execution_count": 8,
|
| 651 |
+
"id": "24fed4e7-4838-44cc-9c3a-0862bdbe173a",
|
| 652 |
+
"metadata": {},
|
| 653 |
+
"outputs": [
|
| 654 |
+
{
|
| 655 |
+
"data": {
|
| 656 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAey0lEQVR4nO3df0xd9f3H8deFymXMcpURL8WCzG124o/L5Jd0dpblboQ6snZZxvaHItm6ZcFFc6NL+w+4rJMsMUiynIXtmyDZr8gaIy5zqanXH/gDQwvFVfEXjhkWvZc26r3luoBezvePxau0UHvhlvs5nOcjuX/ccw/nvDkhl2fuPedej23btgAAAAyRk+0BAAAAPok4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGCUTdkeIF2Li4t66623tHnzZnk8nmyPAwAAzoFt2zp16pRKS0uVk3P210YcFydvvfWWysrKsj0GAABYhZmZGW3duvWs6zgmTizLkmVZ+vDDDyX975crLCzM8lQAAOBcxONxlZWVafPmzZ+6rsdp360Tj8fl8/kUi8WIEwAAHCKd/9+cEAsAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMQpwAAACjECcAAMAoxAkAADBKVuJkenpajY2Nqqys1DXXXKNEIpGNMQAAgIE2ZWOnt956qw4cOKAdO3bonXfekdfrzcYYy7vbd9r9WHbmAADApdY9Tl566SVdcMEF2rFjhySpqKhovUcAAAAGS/ttneHhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo6mHnv99dd14YUXqqWlRdddd53uueeeNf0CAABgY0k7ThKJhAKBgCzLWvbxwcFBhUIhdXV1aXx8XIFAQE1NTZqdnZUkffjhh3r66af129/+ViMjIzp8+LAOHz68tt8CAABsGGnHSXNzsw4cOKA9e/Ys+3hPT4/27t2r9vZ2VVZWqq+vTwUFBerv75ckXXrppaqpqVFZWZm8Xq927dqliYmJFfc3Pz+veDy+5AYAADaujF6ts7CwoLGxMQWDwY93kJOjYDCokZERSVJtba1mZ2f17rvvanFxUcPDw7ryyitX3GZ3d7d8Pl/qVlZWlsmRAQCAYTIaJydPnlQymZTf71+y3O/3KxKJSJI2bdqke+65R1/72td07bXX6ktf+pK+9a1vrbjN/fv3KxaLpW4zMzOZHBkAABgmK5cSNzc3q7m5+ZzW9Xq9Zl1qDAAAzquMvnJSXFys3NxcRaPRJcuj0ahKSkrWtG3LslRZWana2to1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw2E1NDSsadsdHR2anJzUkSNH1jomAAAwWNpv68zNzWlqaip1f3p6WhMTEyoqKlJ5eblCoZDa2tpUU1Ojuro69fb2KpFIqL29PaODAwCAjSntODl69KgaGxtT90OhkCSpra1NAwMDam1t1YkTJ9TZ2alIJKKqqiodOnTojJNkAQAAluOxbdvO9hDnwrIsWZalZDKp1157TbFYTIWFhZnfEd+tAwBAxsXjcfl8vnP6/52VbyVeDc45AQDAHRwTJwAAwB0cEydcSgwAgDs4Jk54WwcAAHdwTJwAAAB3IE4AAIBRiBMAAGAUx8QJJ8QCAOAOjokTTogFAMAdHBMnAADAHYgTAABgFOIEAAAYxTFxwgmxAAC4g2PihBNiAQBwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAURwTJ1ytAwCAOzgmTrhaBwAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFMfECR/CBgCAOzgmTvgQNgAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwsfXAwDgDo6JEz6+HgAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABhlUzZ2WlFRocLCQuXk5Ojiiy/WE088kY0xAACAgbISJ5L03HPP6cILL8zW7gEAgKF4WwcAABgl7TgZHh5WS0uLSktL5fF4NDQ0dMY6lmWpoqJC+fn5qq+v1+jo6JLHPR6PbrzxRtXW1urPf/7zqocHAAAbT9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsbGqdZ555RmNjY/rb3/6me+65R//85z9X/xsAAIANJe04aW5u1oEDB7Rnz55lH+/p6dHevXvV3t6uyspK9fX1qaCgQP39/al1Lr30UknSli1btGvXLo2Pj6+4v/n5ecXj8SU3AACwcWX0nJOFhQWNjY0pGAx+vIOcHAWDQY2MjEj63ysvp06dkiTNzc3p8ccf11VXXbXiNru7u+Xz+VK3srKyTI4MAAAMk9E4OXnypJLJpPx+/5Llfr9fkUhEkhSNRnXDDTcoEAjo+uuv1y233KLa2toVt7l//37FYrHUbWZmJpMjAwAAw6z7pcSXX365XnjhhXNe3+v1yuv1nseJAACASTL6yklxcbFyc3MVjUaXLI9GoyopKVnTti3LUmVl5VlfZQEAAM6X0TjJy8tTdXW1wuFwatni4qLC4bAaGhrWtO2Ojg5NTk7qyJEjax0TAAAYLO23debm5jQ1NZW6Pz09rYmJCRUVFam8vFyhUEhtbW2qqalRXV2dent7lUgk1N7entHBAQDAxpR2nBw9elSNjY2p+6FQSJLU1tamgYEBtba26sSJE+rs7FQkElFVVZUOHTp0xkmy6bIsS5ZlKZlMrmk7AADAbB7btu1sD5GOeDwun8+nWCymwsLCzO/gbt9p92OZ3wcAAC6Tzv9vvlsHAAAYhTgBAABGcUyccCkxAADu4Jg44VJiAADcwTFxAgAA3IE4AQAARnFMnHDOCQAA7uCYOOGcEwAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTHxAlX6wAA4A6OiROu1gEAwB0cEycAAMAdiBMAAGAU4gQAABiFOAEAAEZxTJxwtQ4AAO7gmDjhah0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEZxTJzwOScAALiDY+KEzzkBAMAdHBMnAADAHYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTHxAnfrQMAgDs4Jk74bh0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABglKzFyfvvv6/LLrtMd955Z7ZGAAAABspanPzqV7/S9ddfn63dAwAAQ2UlTl5//XW98soram5uzsbuAQCAwdKOk+HhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo4uefzOO+9Ud3f3qocGAAAbV9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsrCTp4Ycf1hVXXKErrrhibZMDAIANaVO6P9Dc3HzWt2N6enq0d+9etbe3S5L6+vr0yCOPqL+/X/v27dPzzz+vBx54QAcPHtTc3Jw++OADFRYWqrOzc9ntzc/Pa35+PnU/Ho+nOzIAAHCQjJ5zsrCwoLGxMQWDwY93kJOjYDCokZERSVJ3d7dmZmb073//W/fee6/27t27Yph8tL7P50vdysrKMjkyAAAwTEbj5OTJk0omk/L7/UuW+/1+RSKRVW1z//79isViqdvMzEwmRgUAAIZK+22dTLr11ls/dR2v1yuv13v+hwEAAEbI6CsnxcXFys3NVTQaXbI8Go2qpKRkTdu2LEuVlZWqra1d03YAAIDZMhoneXl5qq6uVjgcTi1bXFxUOBxWQ0PDmrbd0dGhyclJHTlyZK1jAgAAg6X9ts7c3JympqZS96enpzUxMaGioiKVl5crFAqpra1NNTU1qqurU29vrxKJROrqHQAAgLNJO06OHj2qxsbG1P1QKCRJamtr08DAgFpbW3XixAl1dnYqEomoqqpKhw4dOuMk2XRZliXLspRMJte0HQAAYDaPbdt2todIRzwel8/nUywWU2FhYeZ3cLfvtPuxzO8DAACXSef/d9a++A8AAGA5xAkAADCKY+KES4kBAHAHx8QJlxIDAOAOjokTAADgDsQJAAAwimPihHNOAABwB8fECeecAADgDo6JEwAA4A7ECQAAMApxAgAAjOKYOOGEWAAA3MExccIJsQAAuINj4gQAALgDcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oRLiQEAcAfHxAmXEgMA4A6OiRMAAOAOxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oTPOQEAwB0cEyd8zgkAAO7gmDgBAADuQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjOKYOOHj6wEAcAfHxAkfXw8AgDs4Jk4AAIA7ECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMsu5x8t5776mmpkZVVVW6+uqr9X//93/rPQIAADDYpvXe4ebNmzU8PKyCggIlEgldffXV+s53vqPPfe5z6z0KAAAw0Lq/cpKbm6uCggJJ0vz8vGzblm3b6z0GAAAwVNpxMjw8rJaWFpWWlsrj8WhoaOiMdSzLUkVFhfLz81VfX6/R0dElj7/33nsKBALaunWr7rrrLhUXF6/6FwAAABtL2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Oxsap2LLrpIL7zwgqanp/WXv/xF0Wh09b8BAADYUNKOk+bmZh04cEB79uxZ9vGenh7t3btX7e3tqqysVF9fnwoKCtTf33/Gun6/X4FAQE8//fSK+5ufn1c8Hl9yAwAAG1dGzzlZWFjQ2NiYgsHgxzvIyVEwGNTIyIgkKRqN6tSpU5KkWCym4eFhbdu2bcVtdnd3y+fzpW5lZWWZHBkAABgmo3Fy8uRJJZNJ+f3+Jcv9fr8ikYgk6c0339SOHTsUCAS0Y8cO/exnP9M111yz4jb379+vWCyWus3MzGRyZAAAYJh1v5S4rq5OExMT57y+1+uV1+s9fwMBAACjZPSVk+LiYuXm5p5xgms0GlVJScmatm1ZliorK1VbW7um7QAAALNlNE7y8vJUXV2tcDicWra4uKhwOKyGhoY1bbujo0OTk5M6cuTIWscEAAAGS/ttnbm5OU1NTaXuT09Pa2JiQkVFRSovL1coFFJbW5tqampUV1en3t5eJRIJtbe3Z3RwAACwMaUdJ0ePHlVjY2PqfigUkiS1tbVpYGBAra2tOnHihDo7OxWJRFRVVaVDhw6dcZIsAADAcjy2Qz473rIsWZalZDKp1157TbFYTIWFhZnf0d2+0+7HMr8PAABcJh6Py+fzndP/73X/bp3V4pwTAADcwTFxAgAA3MExccKlxAAAuINj4oS3dQAAcAfHxAkAAHAH4gQAABiFOAEAAEZxTJxwQiwAAO7gmDjhhFgAANzBMXECAADcgTgBAABGIU4AAIBRHBMnnBALAIA7OCZOOCEWAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYxTFxwtU6AAC4g2PihKt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwuecAADgDo6JEz7nBAAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfLcOAADu4Jg44bt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFHWPU5mZma0c+dOVVZW6tprr9XBgwfXewQAAGCwTeu+w02b1Nvbq6qqKkUiEVVXV2vXrl367Gc/u96jAAAAA617nGzZskVbtmyRJJWUlKi4uFjvvPMOcQIAACSt4m2d4eFhtbS0qLS0VB6PR0NDQ2esY1mWKioqlJ+fr/r6eo2Oji67rbGxMSWTSZWVlaU9OAAA2JjSjpNEIqFAICDLspZ9fHBwUKFQSF1dXRofH1cgEFBTU5NmZ2eXrPfOO+/olltu0e9///vVTQ4AADaktN/WaW5uVnNz84qP9/T0aO/evWpvb5ck9fX16ZFHHlF/f7/27dsnSZqfn9fu3bu1b98+bd++/az7m5+f1/z8fOp+PB5Pd2QAAOAgGb1aZ2FhQWNjYwoGgx/vICdHwWBQIyMjkiTbtnXrrbfq61//um6++eZP3WZ3d7d8Pl/qxltAAABsbBmNk5MnTyqZTMrv9y9Z7vf7FYlEJEnPPvusBgcHNTQ0pKqqKlVVVen48eMrbnP//v2KxWKp28zMTCZHBgAAhln3q3VuuOEGLS4unvP6Xq9XXq/3PE4EAABMktFXToqLi5Wbm6toNLpkeTQaVUlJyZq2bVmWKisrVVtbu6btAAAAs2U0TvLy8lRdXa1wOJxatri4qHA4rIaGhjVtu6OjQ5OTkzpy5MhaxwQAAAZL+22dubk5TU1Npe5PT09rYmJCRUVFKi8vVygUUltbm2pqalRXV6fe3l4lEonU1TsAAABnk3acHD16VI2Njan7oVBIktTW1qaBgQG1trbqxIkT6uzsVCQSUVVVlQ4dOnTGSbLpsixLlmUpmUyuaTsAAMBsHtu27WwPkY54PC6fz6dYLKbCwsLM7+Bu32n3Y5nfBwAALpPO/+91/1ZiAACAsyFOAACAURwTJ1xKDACAOzgmTriUGAAAd3BMnAAAAHcgTgAAgFEcEyeccwIAgDs4Jk445wQAAHdwTJwAAAB3IE4AAIBRiBMAAGAU4gQAABjFMXHC1ToAALiDY+KEq3UAAHAHx8QJAABwB+IEAAAYhTgBAABGIU4AAIBRHBMnXK0DAIA7OCZOuFoHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfM4JAADu4Jg44XNOAABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABjFMXHCd+sAAOAOjokTvlsHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGCUrcbJnzx5dfPHF+u53v5uN3QMAAINlJU5uv/12/eEPf8jGrgEAgOGyEic7d+7U5s2bs7FrAABguLTjZHh4WC0tLSotLZXH49HQ0NAZ61iWpYqKCuXn56u+vl6jo6OZmBUAALhA2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Ozsqgacn59XPB5fcgMAABtX2nHS3NysAwcOaM+ePcs+3tPTo71796q9vV2VlZXq6+tTQUGB+vv7VzVgd3e3fD5f6lZWVraq7QAAAGfI6DknCwsLGhsbUzAY/HgHOTkKBoMaGRlZ1Tb379+vWCyWus3MzGRqXAAAYKBNmdzYyZMnlUwm5ff7lyz3+/165ZVXUveDwaBeeOEFJRIJbd26VQcPHlRDQ8Oy2/R6vfJ6vZkcEwAAGCyjcXKuHnvssbR/xrIsWZalZDJ5HiYCAMCl7vYtsyy2/nN8Qkbf1ikuLlZubq6i0eiS5dFoVCUlJWvadkdHhyYnJ3XkyJE1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw+EV37YBAAD4pLTf1pmbm9PU1FTq/vT0tCYmJlRUVKTy8nKFQiG1tbWppqZGdXV16u3tVSKRUHt7e0YHBwAAG1PacXL06FE1Njam7odCIUlSW1ubBgYG1NraqhMnTqizs1ORSERVVVU6dOjQGSfJpotzTgAAcAePbdt2todIRzwel8/nUywWU2FhYeZ3cPqJQVk+KQgAgPNqnU6ITef/d1a+WwcAAGAlxAkAADCKY+LEsixVVlaqtrY226MAAIDzyDFxwuecAADgDo6JEwAA4A7ECQAAMIpj4oRzTgAAcAfHxAnnnAAA4A6OiRMAAOAOxAkAADBK2t+tk20ffdp+PB4/PzuYP+3T/M/XfgAAMMHp//ek8/K/76P/2+fyrTmO+W6dj774b2FhQW+88Ua2xwEAAKswMzOjrVu3nnUdx8TJRxYXF/XWW29p8+bN8ng8Gd12PB5XWVmZZmZmzs+XCroIxzJzOJaZw7HMHI5l5rjlWNq2rVOnTqm0tFQ5OWc/q8Rxb+vk5OR8anGtVWFh4Yb+A1lPHMvM4VhmDscycziWmeOGY+nzLfMNyMvghFgAAGAU4gQAABiFOPkEr9errq4ueb3ebI/ieBzLzOFYZg7HMnM4lpnDsTyT406IBQAAGxuvnAAAAKMQJwAAwCjECQAAMApxAgAAjOK6OLEsSxUVFcrPz1d9fb1GR0fPuv7Bgwf15S9/Wfn5+brmmmv0j3/8Y50mNV86x3JgYEAej2fJLT8/fx2nNdPw8LBaWlpUWloqj8ejoaGhT/2ZJ598Utddd528Xq+++MUvamBg4LzP6QTpHssnn3zyjL9Jj8ejSCSyPgMbrLu7W7W1tdq8ebMuueQS7d69W6+++uqn/hzPl2dazbHk+dJlcTI4OKhQKKSuri6Nj48rEAioqalJs7Ozy67/3HPP6Qc/+IF++MMf6tixY9q9e7d2796tF198cZ0nN0+6x1L636cfvv3226nbm2++uY4TmymRSCgQCMiyrHNaf3p6WjfddJMaGxs1MTGhO+64Qz/60Y/06KOPnudJzZfusfzIq6++uuTv8pJLLjlPEzrHU089pY6ODj3//PM6fPiwPvjgA33zm99UIpFY8Wd4vlzeao6lxPOlbBepq6uzOzo6UveTyaRdWlpqd3d3L7v+9773Pfumm25asqy+vt7+yU9+cl7ndIJ0j+X9999v+3y+dZrOmSTZDz300FnX+fnPf25fddVVS5a1trbaTU1N53Ey5zmXY/nEE0/Ykux33313XWZystnZWVuS/dRTT624Ds+X5+ZcjiXPl7btmldOFhYWNDY2pmAwmFqWk5OjYDCokZGRZX9mZGRkyfqS1NTUtOL6brGaYylJc3Nzuuyyy1RWVqZvf/vbeumll9Zj3A2Fv8nMq6qq0pYtW/SNb3xDzz77bLbHMVIsFpMkFRUVrbgOf5vn5lyOpcTzpWvi5OTJk0omk/L7/UuW+/3+Fd9jjkQiaa3vFqs5ltu2bVN/f78efvhh/elPf9Li4qK2b9+u//znP+sx8oax0t9kPB7Xf//73yxN5UxbtmxRX1+fHnzwQT344IMqKyvTzp07NT4+nu3RjLK4uKg77rhDX/3qV3X11VevuB7Pl5/uXI8lz5cO/FZiOFNDQ4MaGhpS97dv364rr7xSv/vd7/TLX/4yi5PBrbZt26Zt27al7m/fvl1vvPGG7rvvPv3xj3/M4mRm6ejo0Isvvqhnnnkm26M43rkeS54vXfTKSXFxsXJzcxWNRpcsj0ajKikpWfZnSkpK0lrfLVZzLE93wQUX6Ctf+YqmpqbOx4gb1kp/k4WFhfrMZz6Tpak2jrq6Ov4mP+G2227T3//+dz3xxBPaunXrWdfl+fLs0jmWp3Pj86Vr4iQvL0/V1dUKh8OpZYuLiwqHw0sK9ZMaGhqWrC9Jhw8fXnF9t1jNsTxdMpnU8ePHtWXLlvM15obE3+T5NTExwd+kJNu2ddttt+mhhx7S448/rs9//vOf+jP8bS5vNcfydK58vsz2Gbnr6YEHHrC9Xq89MDBgT05O2j/+8Y/tiy66yI5EIrZt2/bNN99s79u3L7X+s88+a2/atMm+99577Zdfftnu6uqyL7jgAvv48ePZ+hWMke6x/MUvfmE/+uij9htvvGGPjY3Z3//+9+38/Hz7pZdeytavYIRTp07Zx44ds48dO2ZLsnt6euxjx47Zb775pm3btr1v3z775ptvTq3/r3/9yy4oKLDvuusu++WXX7Yty7Jzc3PtQ4cOZetXMEa6x/K+++6zh4aG7Ndff90+fvy4ffvtt9s5OTn2Y489lq1fwRg//elPbZ/PZz/55JP222+/nbq9//77qXV4vjw3qzmWPF/atqvixLZt+ze/+Y1dXl5u5+Xl2XV1dfbzzz+feuzGG2+029ralqz/17/+1b7iiivsvLw8+6qrrrIfeeSRdZ7YXOkcyzvuuCO1rt/vt3ft2mWPj49nYWqzfHQ56+m3j45dW1ubfeONN57xM1VVVXZeXp59+eWX2/fff/+6z22idI/lr3/9a/sLX/iCnZ+fbxcVFdk7d+60H3/88ewMb5jljqOkJX9rPF+em9UcS54vbdtj27a9fq/TAAAAnJ1rzjkBAADOQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwyv8D6KAeY7AISbEAAAAASUVORK5CYII=",
|
| 657 |
+
"text/plain": [
|
| 658 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 659 |
+
]
|
| 660 |
+
},
|
| 661 |
+
"metadata": {},
|
| 662 |
+
"output_type": "display_data"
|
| 663 |
+
}
|
| 664 |
+
],
|
| 665 |
+
"source": [
|
| 666 |
+
"import matplotlib.pyplot as plt\n",
|
| 667 |
+
"%matplotlib inline\n",
|
| 668 |
+
"plt.hist(abs(intermediate-outputs[0]).ravel(), bins = 100)\n",
|
| 669 |
+
"plt.yscale('log')\n",
|
| 670 |
+
"plt.show()"
|
| 671 |
+
]
|
| 672 |
+
},
|
| 673 |
+
{
|
| 674 |
+
"cell_type": "code",
|
| 675 |
+
"execution_count": 15,
|
| 676 |
+
"id": "71cb219e-b91a-4629-99f6-00db786903c7",
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"outputs": [
|
| 679 |
+
{
|
| 680 |
+
"data": {
|
| 681 |
+
"text/plain": [
|
| 682 |
+
"tensor([1.1902e-03, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00,\n",
|
| 683 |
+
" 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00])"
|
| 684 |
+
]
|
| 685 |
+
},
|
| 686 |
+
"execution_count": 15,
|
| 687 |
+
"metadata": {},
|
| 688 |
+
"output_type": "execute_result"
|
| 689 |
+
}
|
| 690 |
+
],
|
| 691 |
+
"source": [
|
| 692 |
+
"torch.sort(abs(intermediate-outputs[0]).ravel())[0][-10:]"
|
| 693 |
+
]
|
| 694 |
+
},
|
| 695 |
+
{
|
| 696 |
+
"cell_type": "code",
|
| 697 |
+
"execution_count": null,
|
| 698 |
+
"id": "92ba5920-5451-4bb1-af0e-5ea987841ab1",
|
| 699 |
+
"metadata": {},
|
| 700 |
+
"outputs": [],
|
| 701 |
+
"source": [
|
| 702 |
+
"import onnxruntime as ort\n",
|
| 703 |
+
"\n",
|
| 704 |
+
"session_options = ort.SessionOptions()\n",
|
| 705 |
+
"session_options.log_severity_level = 0 # Verbose logging\n",
|
| 706 |
+
"session = ort.InferenceSession(\"models_mask/preproc_test.onnx\", sess_options=session_options)"
|
| 707 |
+
]
|
| 708 |
+
},
|
| 709 |
+
{
|
| 710 |
+
"cell_type": "code",
|
| 711 |
+
"execution_count": null,
|
| 712 |
+
"id": "3277b343-245d-4ac8-a91c-373061dcbf53",
|
| 713 |
+
"metadata": {},
|
| 714 |
+
"outputs": [],
|
| 715 |
+
"source": [
|
| 716 |
+
"import matplotlib.pyplot as plt\n",
|
| 717 |
+
"%matplotlib inline\n",
|
| 718 |
+
"plt.imshow(outputs[0][0,8,:,:])\n",
|
| 719 |
+
"plt.show()"
|
| 720 |
+
]
|
| 721 |
+
},
|
| 722 |
+
{
|
| 723 |
+
"cell_type": "code",
|
| 724 |
+
"execution_count": null,
|
| 725 |
+
"id": "67e99ef8-e49a-4037-a818-244555b0bdc5",
|
| 726 |
+
"metadata": {},
|
| 727 |
+
"outputs": [],
|
| 728 |
+
"source": [
|
| 729 |
+
"import onnx\n",
|
| 730 |
+
"\n",
|
| 731 |
+
"# Path to your ONNX model\n",
|
| 732 |
+
"model_path = \"models/model-47-99.125.onnx\"\n",
|
| 733 |
+
"\n",
|
| 734 |
+
"# Load the ONNX model\n",
|
| 735 |
+
"onnx_model = onnx.load(model_path)\n",
|
| 736 |
+
"\n",
|
| 737 |
+
"# Check the model for validity\n",
|
| 738 |
+
"onnx.checker.check_model(onnx_model)\n",
|
| 739 |
+
"\n",
|
| 740 |
+
"# Print model graph structure (optional)\n",
|
| 741 |
+
"print(onnx.helper.printable_graph(onnx_model.graph))\n"
|
| 742 |
+
]
|
| 743 |
+
}
|
| 744 |
+
],
|
| 745 |
+
"metadata": {
|
| 746 |
+
"kernelspec": {
|
| 747 |
+
"display_name": "Python 3 (ipykernel)",
|
| 748 |
+
"language": "python",
|
| 749 |
+
"name": "python3"
|
| 750 |
+
},
|
| 751 |
+
"language_info": {
|
| 752 |
+
"codemirror_mode": {
|
| 753 |
+
"name": "ipython",
|
| 754 |
+
"version": 3
|
| 755 |
+
},
|
| 756 |
+
"file_extension": ".py",
|
| 757 |
+
"mimetype": "text/x-python",
|
| 758 |
+
"name": "python",
|
| 759 |
+
"nbconvert_exporter": "python",
|
| 760 |
+
"pygments_lexer": "ipython3",
|
| 761 |
+
"version": "3.11.9"
|
| 762 |
+
}
|
| 763 |
+
},
|
| 764 |
+
"nbformat": 4,
|
| 765 |
+
"nbformat_minor": 5
|
| 766 |
+
}
|
models/.ipynb_checkpoints/benchmark_model-8bit-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/benchmark_model-Copy1-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/benchmark_model-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/benchmark_model_treshold-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/benchmark_model_vanilla-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/eval_basic-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"6\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 24 |
+
" 0%| | 0/48 [00:22<?, ?it/s]\n"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"ename": "AttributeError",
|
| 29 |
+
"evalue": "Caught AttributeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n return forward_call(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/projects/frbnn_narrow/CNN/resnet_model.py\", line 106, in forward\n return x, self.mask, self.value\n ^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1729, in __getattr__\n raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\nAttributeError: 'ResNet' object has no attribute 'mask'\n",
|
| 30 |
+
"output_type": "error",
|
| 31 |
+
"traceback": [
|
| 32 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 33 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
| 34 |
+
"Cell \u001b[0;32mIn[1], line 50\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(testloader):\n\u001b[1;32m 49\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\n\u001b[0;32m---> 50\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(inputs, return_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 51\u001b[0m _, predicted \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmax(outputs, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 52\u001b[0m results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mextend(outputs\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;241m.\u001b[39mtolist())\n",
|
| 35 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
| 36 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 37 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:186\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodule_kwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 185\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 186\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparallel_apply(replicas, inputs, module_kwargs)\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
|
| 38 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:201\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Any]:\n\u001b[0;32m--> 201\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parallel_apply(replicas, inputs, kwargs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(replicas)])\n",
|
| 39 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:108\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 106\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m--> 108\u001b[0m output\u001b[38;5;241m.\u001b[39mreraise()\n\u001b[1;32m 109\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
|
| 40 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/_utils.py:706\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 702\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 703\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 704\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 706\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
|
| 41 |
+
"\u001b[0;31mAttributeError\u001b[0m: Caught AttributeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n return forward_call(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/projects/frbnn_narrow/CNN/resnet_model.py\", line 106, in forward\n return x, self.mask, self.value\n ^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1729, in __getattr__\n raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\nAttributeError: 'ResNet' object has no attribute 'mask'\n"
|
| 42 |
+
]
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
+
"source": [
|
| 46 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 47 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 48 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 49 |
+
"from tqdm import tqdm\n",
|
| 50 |
+
"import torch\n",
|
| 51 |
+
"import numpy as np\n",
|
| 52 |
+
"from resnet_model import ResidualBlock, ResNet\n",
|
| 53 |
+
"import torch\n",
|
| 54 |
+
"import torch.nn as nn\n",
|
| 55 |
+
"import torch.optim as optim\n",
|
| 56 |
+
"from tqdm import tqdm \n",
|
| 57 |
+
"import torch.nn.functional as F\n",
|
| 58 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 59 |
+
"import pickle\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"torch.manual_seed(1)\n",
|
| 62 |
+
"# torch.manual_seed(42)\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 66 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 67 |
+
"print(num_gpus)\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 70 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"num_classes = 2\n",
|
| 73 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 76 |
+
"model = nn.DataParallel(model)\n",
|
| 77 |
+
"model = model.to(device)\n",
|
| 78 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 79 |
+
"print(\"num params \",params)\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"model_1 = 'models/model-23-99.045.pt'\n",
|
| 82 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 83 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 84 |
+
"model = model.eval()\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"# eval\n",
|
| 87 |
+
"val_loss = 0.0\n",
|
| 88 |
+
"correct_valid = 0\n",
|
| 89 |
+
"total = 0\n",
|
| 90 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 91 |
+
"model.eval()\n",
|
| 92 |
+
"with torch.no_grad():\n",
|
| 93 |
+
" for images, labels in tqdm(testloader):\n",
|
| 94 |
+
" inputs, labels = images.to(device), labels\n",
|
| 95 |
+
" outputs = model(inputs)\n",
|
| 96 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 97 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 98 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 99 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 100 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 101 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 102 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 103 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 104 |
+
" total += labels[0].size(0)\n",
|
| 105 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 106 |
+
"# Calculate training accuracy after each epoch\n",
|
| 107 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 108 |
+
"print(\"===========================\")\n",
|
| 109 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 110 |
+
"print(\"===========================\")\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"import pickle\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"# Pickle the dictionary to a file\n",
|
| 115 |
+
"with open('models/test_42.pkl', 'wb') as f:\n",
|
| 116 |
+
" pickle.dump(results, f)"
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"cell_type": "code",
|
| 121 |
+
"execution_count": null,
|
| 122 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"outputs": [],
|
| 125 |
+
"source": [
|
| 126 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 127 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 128 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 129 |
+
"from tqdm import tqdm\n",
|
| 130 |
+
"import torch\n",
|
| 131 |
+
"import numpy as np\n",
|
| 132 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 133 |
+
"import torch\n",
|
| 134 |
+
"import torch.nn as nn\n",
|
| 135 |
+
"import torch.optim as optim\n",
|
| 136 |
+
"from tqdm import tqdm \n",
|
| 137 |
+
"import torch.nn.functional as F\n",
|
| 138 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 139 |
+
"import pickle\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"torch.manual_seed(1)\n",
|
| 142 |
+
"# torch.manual_seed(42)\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 146 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 147 |
+
"print(num_gpus)\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 150 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"num_classes = 2\n",
|
| 153 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 156 |
+
"model = nn.DataParallel(model)\n",
|
| 157 |
+
"model = model.to(device)\n",
|
| 158 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 159 |
+
"print(\"num params \",params)\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"model_1 = 'models/model-14-98.005.pt'\n",
|
| 163 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 164 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 165 |
+
"model = model.eval()\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"# eval\n",
|
| 168 |
+
"val_loss = 0.0\n",
|
| 169 |
+
"correct_valid = 0\n",
|
| 170 |
+
"total = 0\n",
|
| 171 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 172 |
+
"model.eval()\n",
|
| 173 |
+
"with torch.no_grad():\n",
|
| 174 |
+
" for images, labels in tqdm(testloader):\n",
|
| 175 |
+
" inputs, labels = images.to(device), labels\n",
|
| 176 |
+
" outputs = model(inputs)\n",
|
| 177 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 178 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 179 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 180 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 181 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 182 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 183 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 184 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 185 |
+
" total += labels[0].size(0)\n",
|
| 186 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 187 |
+
" \n",
|
| 188 |
+
"# Calculate training accuracy after each epoch\n",
|
| 189 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 190 |
+
"print(\"===========================\")\n",
|
| 191 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 192 |
+
"print(\"===========================\")\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"import pickle\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"# Pickle the dictionary to a file\n",
|
| 197 |
+
"with open('models/test_1.pkl', 'wb') as f:\n",
|
| 198 |
+
" pickle.dump(results, f)"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": null,
|
| 204 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [],
|
| 207 |
+
"source": [
|
| 208 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 209 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 210 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 211 |
+
"from tqdm import tqdm\n",
|
| 212 |
+
"import torch\n",
|
| 213 |
+
"import numpy as np\n",
|
| 214 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 215 |
+
"import torch\n",
|
| 216 |
+
"import torch.nn as nn\n",
|
| 217 |
+
"import torch.optim as optim\n",
|
| 218 |
+
"from tqdm import tqdm \n",
|
| 219 |
+
"import torch.nn.functional as F\n",
|
| 220 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 221 |
+
"import pickle\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"torch.manual_seed(1)\n",
|
| 224 |
+
"# torch.manual_seed(42)\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 228 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 229 |
+
"print(num_gpus)\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 232 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"num_classes = 2\n",
|
| 235 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 238 |
+
"model = nn.DataParallel(model)\n",
|
| 239 |
+
"model = model.to(device)\n",
|
| 240 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 241 |
+
"print(\"num params \",params)\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"model_1 = 'models/model-28-98.955.pt'\n",
|
| 245 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 246 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 247 |
+
"model = model.eval()\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"# eval\n",
|
| 250 |
+
"val_loss = 0.0\n",
|
| 251 |
+
"correct_valid = 0\n",
|
| 252 |
+
"total = 0\n",
|
| 253 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 254 |
+
"model.eval()\n",
|
| 255 |
+
"with torch.no_grad():\n",
|
| 256 |
+
" for images, labels in tqdm(testloader):\n",
|
| 257 |
+
" inputs, labels = images.to(device), labels\n",
|
| 258 |
+
" outputs = model(inputs)\n",
|
| 259 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 260 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 261 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 262 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 263 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 264 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 265 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 266 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 267 |
+
" total += labels[0].size(0)\n",
|
| 268 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 269 |
+
" \n",
|
| 270 |
+
"# Calculate training accuracy after each epoch\n",
|
| 271 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 272 |
+
"print(\"===========================\")\n",
|
| 273 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 274 |
+
"print(\"===========================\")\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"import pickle\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"# Pickle the dictionary to a file\n",
|
| 279 |
+
"with open('models/test_7109.pkl', 'wb') as f:\n",
|
| 280 |
+
" pickle.dump(results, f)"
|
| 281 |
+
]
|
| 282 |
+
}
|
| 283 |
+
],
|
| 284 |
+
"metadata": {
|
| 285 |
+
"kernelspec": {
|
| 286 |
+
"display_name": "Python 3 (ipykernel)",
|
| 287 |
+
"language": "python",
|
| 288 |
+
"name": "python3"
|
| 289 |
+
},
|
| 290 |
+
"language_info": {
|
| 291 |
+
"codemirror_mode": {
|
| 292 |
+
"name": "ipython",
|
| 293 |
+
"version": 3
|
| 294 |
+
},
|
| 295 |
+
"file_extension": ".py",
|
| 296 |
+
"mimetype": "text/x-python",
|
| 297 |
+
"name": "python",
|
| 298 |
+
"nbconvert_exporter": "python",
|
| 299 |
+
"pygments_lexer": "ipython3",
|
| 300 |
+
"version": "3.11.9"
|
| 301 |
+
}
|
| 302 |
+
},
|
| 303 |
+
"nbformat": 4,
|
| 304 |
+
"nbformat_minor": 5
|
| 305 |
+
}
|
models/.ipynb_checkpoints/eval_basic-extend-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"2\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 24 |
+
"100%|███████████████████████████████████████████| 48/48 [00:39<00:00, 1.21it/s]"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"name": "stdout",
|
| 29 |
+
"output_type": "stream",
|
| 30 |
+
"text": [
|
| 31 |
+
"===========================\n",
|
| 32 |
+
"accuracy: 98.82\n",
|
| 33 |
+
"===========================\n",
|
| 34 |
+
"False Positive Rate: 0.010\n",
|
| 35 |
+
"Precision: 0.990\n",
|
| 36 |
+
"Recall: 0.986\n",
|
| 37 |
+
"F1 Score: 0.988\n"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"name": "stderr",
|
| 42 |
+
"output_type": "stream",
|
| 43 |
+
"text": [
|
| 44 |
+
"\n"
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"source": [
|
| 49 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 50 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 51 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 52 |
+
"from tqdm import tqdm\n",
|
| 53 |
+
"import torch\n",
|
| 54 |
+
"import numpy as np\n",
|
| 55 |
+
"from resnet_model import ResidualBlock, ResNet\n",
|
| 56 |
+
"import torch\n",
|
| 57 |
+
"import torch.nn as nn\n",
|
| 58 |
+
"import torch.optim as optim\n",
|
| 59 |
+
"from tqdm import tqdm \n",
|
| 60 |
+
"import torch.nn.functional as F\n",
|
| 61 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 62 |
+
"import pickle\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"torch.manual_seed(1)\n",
|
| 65 |
+
"# torch.manual_seed(42)\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 69 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 70 |
+
"print(num_gpus)\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 73 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"num_classes = 2\n",
|
| 76 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 79 |
+
"model = nn.DataParallel(model)\n",
|
| 80 |
+
"model = model.to(device)\n",
|
| 81 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 82 |
+
"print(\"num params \",params)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"model_1 = 'models/model-23-99.045.pt'\n",
|
| 85 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 86 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 87 |
+
"model = model.eval()\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"# eval\n",
|
| 90 |
+
"val_loss = 0.0\n",
|
| 91 |
+
"correct_valid = 0\n",
|
| 92 |
+
"total = 0\n",
|
| 93 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 94 |
+
"model.eval()\n",
|
| 95 |
+
"with torch.no_grad():\n",
|
| 96 |
+
" for images, labels in tqdm(testloader):\n",
|
| 97 |
+
" inputs, labels = images.to(device), labels\n",
|
| 98 |
+
" outputs = model(inputs)\n",
|
| 99 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 100 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 101 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 102 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 103 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 104 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 105 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 106 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 107 |
+
" total += labels[0].size(0)\n",
|
| 108 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 109 |
+
"# Calculate training accuracy after each epoch\n",
|
| 110 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 111 |
+
"print(\"===========================\")\n",
|
| 112 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 113 |
+
"print(\"===========================\")\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"import pickle\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"# Pickle the dictionary to a file\n",
|
| 118 |
+
"with open('models/test_42.pkl', 'wb') as f:\n",
|
| 119 |
+
" pickle.dump(results, f)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 123 |
+
"from sklearn.metrics import confusion_matrix\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"# Example binary labels\n",
|
| 126 |
+
"true = results['true'] # ground truth\n",
|
| 127 |
+
"pred = results['pred'] # predicted\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"# Compute metrics\n",
|
| 130 |
+
"precision = precision_score(true, pred)\n",
|
| 131 |
+
"recall = recall_score(true, pred)\n",
|
| 132 |
+
"f1 = f1_score(true, pred)\n",
|
| 133 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 134 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"# Compute FPR\n",
|
| 137 |
+
"fpr = fp / (fp + tn)\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 142 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 143 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": 2,
|
| 149 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 150 |
+
"metadata": {},
|
| 151 |
+
"outputs": [
|
| 152 |
+
{
|
| 153 |
+
"name": "stdout",
|
| 154 |
+
"output_type": "stream",
|
| 155 |
+
"text": [
|
| 156 |
+
"2\n",
|
| 157 |
+
"num params encoder 50840\n",
|
| 158 |
+
"num params 21496282\n"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"name": "stderr",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 166 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 167 |
+
"100%|███████████████████████████████████████████| 48/48 [00:51<00:00, 1.07s/it]"
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"name": "stdout",
|
| 172 |
+
"output_type": "stream",
|
| 173 |
+
"text": [
|
| 174 |
+
"===========================\n",
|
| 175 |
+
"accuracy: 97.185\n",
|
| 176 |
+
"===========================\n",
|
| 177 |
+
"False Positive Rate: 0.038\n",
|
| 178 |
+
"Precision: 0.963\n",
|
| 179 |
+
"Recall: 0.981\n",
|
| 180 |
+
"F1 Score: 0.972\n"
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"name": "stderr",
|
| 185 |
+
"output_type": "stream",
|
| 186 |
+
"text": [
|
| 187 |
+
"\n"
|
| 188 |
+
]
|
| 189 |
+
}
|
| 190 |
+
],
|
| 191 |
+
"source": [
|
| 192 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 193 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 194 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 195 |
+
"from tqdm import tqdm\n",
|
| 196 |
+
"import torch\n",
|
| 197 |
+
"import numpy as np\n",
|
| 198 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 199 |
+
"import torch\n",
|
| 200 |
+
"import torch.nn as nn\n",
|
| 201 |
+
"import torch.optim as optim\n",
|
| 202 |
+
"from tqdm import tqdm \n",
|
| 203 |
+
"import torch.nn.functional as F\n",
|
| 204 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 205 |
+
"import pickle\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"torch.manual_seed(1)\n",
|
| 208 |
+
"# torch.manual_seed(42)\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 212 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 213 |
+
"print(num_gpus)\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 216 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"num_classes = 2\n",
|
| 219 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 222 |
+
"model = nn.DataParallel(model)\n",
|
| 223 |
+
"model = model.to(device)\n",
|
| 224 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 225 |
+
"print(\"num params \",params)\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"model_1 = 'models/model-14-98.005.pt'\n",
|
| 229 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 230 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 231 |
+
"model = model.eval()\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"# eval\n",
|
| 234 |
+
"val_loss = 0.0\n",
|
| 235 |
+
"correct_valid = 0\n",
|
| 236 |
+
"total = 0\n",
|
| 237 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 238 |
+
"model.eval()\n",
|
| 239 |
+
"with torch.no_grad():\n",
|
| 240 |
+
" for images, labels in tqdm(testloader):\n",
|
| 241 |
+
" inputs, labels = images.to(device), labels\n",
|
| 242 |
+
" outputs = model(inputs)\n",
|
| 243 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 244 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 245 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 246 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 247 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 248 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 249 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 250 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 251 |
+
" total += labels[0].size(0)\n",
|
| 252 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 253 |
+
" \n",
|
| 254 |
+
"# Calculate training accuracy after each epoch\n",
|
| 255 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 256 |
+
"print(\"===========================\")\n",
|
| 257 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 258 |
+
"print(\"===========================\")\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"import pickle\n",
|
| 261 |
+
"\n",
|
| 262 |
+
"# Pickle the dictionary to a file\n",
|
| 263 |
+
"with open('models/test_1.pkl', 'wb') as f:\n",
|
| 264 |
+
" pickle.dump(results, f)\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 268 |
+
"from sklearn.metrics import confusion_matrix\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"# Example binary labels\n",
|
| 271 |
+
"true = results['true'] # ground truth\n",
|
| 272 |
+
"pred = results['pred'] # predicted\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"# Compute metrics\n",
|
| 275 |
+
"precision = precision_score(true, pred)\n",
|
| 276 |
+
"recall = recall_score(true, pred)\n",
|
| 277 |
+
"f1 = f1_score(true, pred)\n",
|
| 278 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 279 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 280 |
+
"\n",
|
| 281 |
+
"# Compute FPR\n",
|
| 282 |
+
"fpr = fp / (fp + tn)\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 285 |
+
"\n",
|
| 286 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 287 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 288 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "code",
|
| 293 |
+
"execution_count": 3,
|
| 294 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [
|
| 297 |
+
{
|
| 298 |
+
"name": "stdout",
|
| 299 |
+
"output_type": "stream",
|
| 300 |
+
"text": [
|
| 301 |
+
"2\n",
|
| 302 |
+
"num params encoder 50840\n",
|
| 303 |
+
"num params 21496282\n"
|
| 304 |
+
]
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"name": "stderr",
|
| 308 |
+
"output_type": "stream",
|
| 309 |
+
"text": [
|
| 310 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 311 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 312 |
+
"100%|███████████████████████████████████████████| 48/48 [00:53<00:00, 1.11s/it]"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"name": "stdout",
|
| 317 |
+
"output_type": "stream",
|
| 318 |
+
"text": [
|
| 319 |
+
"===========================\n",
|
| 320 |
+
"accuracy: 98.455\n",
|
| 321 |
+
"===========================\n",
|
| 322 |
+
"False Positive Rate: 0.010\n",
|
| 323 |
+
"Precision: 0.990\n",
|
| 324 |
+
"Recall: 0.979\n",
|
| 325 |
+
"F1 Score: 0.984\n"
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"name": "stderr",
|
| 330 |
+
"output_type": "stream",
|
| 331 |
+
"text": [
|
| 332 |
+
"\n"
|
| 333 |
+
]
|
| 334 |
+
}
|
| 335 |
+
],
|
| 336 |
+
"source": [
|
| 337 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 338 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 339 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 340 |
+
"from tqdm import tqdm\n",
|
| 341 |
+
"import torch\n",
|
| 342 |
+
"import numpy as np\n",
|
| 343 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 344 |
+
"import torch\n",
|
| 345 |
+
"import torch.nn as nn\n",
|
| 346 |
+
"import torch.optim as optim\n",
|
| 347 |
+
"from tqdm import tqdm \n",
|
| 348 |
+
"import torch.nn.functional as F\n",
|
| 349 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 350 |
+
"import pickle\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"torch.manual_seed(1)\n",
|
| 353 |
+
"# torch.manual_seed(42)\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 357 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 358 |
+
"print(num_gpus)\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 361 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 362 |
+
"\n",
|
| 363 |
+
"num_classes = 2\n",
|
| 364 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 365 |
+
"\n",
|
| 366 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 367 |
+
"model = nn.DataParallel(model)\n",
|
| 368 |
+
"model = model.to(device)\n",
|
| 369 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 370 |
+
"print(\"num params \",params)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"model_1 = 'models/model-28-98.955.pt'\n",
|
| 374 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 375 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 376 |
+
"model = model.eval()\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"# eval\n",
|
| 379 |
+
"val_loss = 0.0\n",
|
| 380 |
+
"correct_valid = 0\n",
|
| 381 |
+
"total = 0\n",
|
| 382 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 383 |
+
"model.eval()\n",
|
| 384 |
+
"with torch.no_grad():\n",
|
| 385 |
+
" for images, labels in tqdm(testloader):\n",
|
| 386 |
+
" inputs, labels = images.to(device), labels\n",
|
| 387 |
+
" outputs = model(inputs)\n",
|
| 388 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 389 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 390 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 391 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 392 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 393 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 394 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 395 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 396 |
+
" total += labels[0].size(0)\n",
|
| 397 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 398 |
+
" \n",
|
| 399 |
+
"# Calculate training accuracy after each epoch\n",
|
| 400 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 401 |
+
"print(\"===========================\")\n",
|
| 402 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 403 |
+
"print(\"===========================\")\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"import pickle\n",
|
| 406 |
+
"\n",
|
| 407 |
+
"# Pickle the dictionary to a file\n",
|
| 408 |
+
"with open('models/test_7109.pkl', 'wb') as f:\n",
|
| 409 |
+
" pickle.dump(results, f)\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 412 |
+
"from sklearn.metrics import confusion_matrix\n",
|
| 413 |
+
"\n",
|
| 414 |
+
"# Example binary labels\n",
|
| 415 |
+
"true = results['true'] # ground truth\n",
|
| 416 |
+
"pred = results['pred'] # predicted\n",
|
| 417 |
+
"\n",
|
| 418 |
+
"# Compute metrics\n",
|
| 419 |
+
"precision = precision_score(true, pred)\n",
|
| 420 |
+
"recall = recall_score(true, pred)\n",
|
| 421 |
+
"f1 = f1_score(true, pred)\n",
|
| 422 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 423 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 424 |
+
"\n",
|
| 425 |
+
"# Compute FPR\n",
|
| 426 |
+
"fpr = fp / (fp + tn)\n",
|
| 427 |
+
"\n",
|
| 428 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 429 |
+
"\n",
|
| 430 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 431 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 432 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 433 |
+
]
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"cell_type": "code",
|
| 437 |
+
"execution_count": 4,
|
| 438 |
+
"id": "ad4ef08f-6a9b-495f-bb93-0a13251f0adb",
|
| 439 |
+
"metadata": {},
|
| 440 |
+
"outputs": [
|
| 441 |
+
{
|
| 442 |
+
"name": "stdout",
|
| 443 |
+
"output_type": "stream",
|
| 444 |
+
"text": [
|
| 445 |
+
"98.15333333333332 0.7007416705811665\n",
|
| 446 |
+
"0.9803333333333333 0.0009428090415820641\n",
|
| 447 |
+
"0.9813333333333333 0.006798692684790386\n",
|
| 448 |
+
"0.019333333333333334 0.013199326582148887\n"
|
| 449 |
+
]
|
| 450 |
+
}
|
| 451 |
+
],
|
| 452 |
+
"source": [
|
| 453 |
+
"# acc\n",
|
| 454 |
+
"print(np.mean([98.82,97.185,98.455]), np.std([98.82,97.185,98.455]))\n",
|
| 455 |
+
"# recall\n",
|
| 456 |
+
"print(np.mean([0.981,0.981, 0.979]), np.std([0.981,0.981, 0.979]))\n",
|
| 457 |
+
"# f1\n",
|
| 458 |
+
"print(np.mean([0.988,0.972,0.984]),np.std([0.988,0.972,0.984]))\n",
|
| 459 |
+
"# fp\n",
|
| 460 |
+
"print(np.mean([0.010,0.038,0.010]),np.std([0.010,0.038,0.010]))"
|
| 461 |
+
]
|
| 462 |
+
}
|
| 463 |
+
],
|
| 464 |
+
"metadata": {
|
| 465 |
+
"kernelspec": {
|
| 466 |
+
"display_name": "Python 3 (ipykernel)",
|
| 467 |
+
"language": "python",
|
| 468 |
+
"name": "python3"
|
| 469 |
+
},
|
| 470 |
+
"language_info": {
|
| 471 |
+
"codemirror_mode": {
|
| 472 |
+
"name": "ipython",
|
| 473 |
+
"version": 3
|
| 474 |
+
},
|
| 475 |
+
"file_extension": ".py",
|
| 476 |
+
"mimetype": "text/x-python",
|
| 477 |
+
"name": "python",
|
| 478 |
+
"nbconvert_exporter": "python",
|
| 479 |
+
"pygments_lexer": "ipython3",
|
| 480 |
+
"version": "3.11.9"
|
| 481 |
+
}
|
| 482 |
+
},
|
| 483 |
+
"nbformat": 4,
|
| 484 |
+
"nbformat_minor": 5
|
| 485 |
+
}
|
models/.ipynb_checkpoints/eval_mask-8-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"6\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 24 |
+
"100%|███████████████████████████████████████████| 48/48 [01:43<00:00, 2.16s/it]"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"name": "stdout",
|
| 29 |
+
"output_type": "stream",
|
| 30 |
+
"text": [
|
| 31 |
+
"===========================\n",
|
| 32 |
+
"accuracy: 98.94\n",
|
| 33 |
+
"===========================\n"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"name": "stderr",
|
| 38 |
+
"output_type": "stream",
|
| 39 |
+
"text": [
|
| 40 |
+
"\n"
|
| 41 |
+
]
|
| 42 |
+
}
|
| 43 |
+
],
|
| 44 |
+
"source": [
|
| 45 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 46 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 47 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 48 |
+
"from tqdm import tqdm\n",
|
| 49 |
+
"import torch\n",
|
| 50 |
+
"import numpy as np\n",
|
| 51 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 52 |
+
"import torch\n",
|
| 53 |
+
"import torch.nn as nn\n",
|
| 54 |
+
"import torch.optim as optim\n",
|
| 55 |
+
"from tqdm import tqdm \n",
|
| 56 |
+
"import torch.nn.functional as F\n",
|
| 57 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 58 |
+
"import pickle\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"torch.manual_seed(1)\n",
|
| 61 |
+
"# torch.manual_seed(42)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 65 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 66 |
+
"print(num_gpus)\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 69 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"num_classes = 2\n",
|
| 72 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 75 |
+
"model = nn.DataParallel(model)\n",
|
| 76 |
+
"model = model.to(device)\n",
|
| 77 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 78 |
+
"print(\"num params \",params)\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"model_1 = 'models_8/model-25-99.31_7109.pt'\n",
|
| 81 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 82 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 83 |
+
"model = model.eval()\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"# eval\n",
|
| 86 |
+
"val_loss = 0.0\n",
|
| 87 |
+
"correct_valid = 0\n",
|
| 88 |
+
"total = 0\n",
|
| 89 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 90 |
+
"model.eval()\n",
|
| 91 |
+
"with torch.no_grad():\n",
|
| 92 |
+
" for images, labels in tqdm(testloader):\n",
|
| 93 |
+
" inputs, labels = images.to(device), labels\n",
|
| 94 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 95 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 96 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 97 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 98 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 99 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 100 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 101 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 102 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 103 |
+
" total += labels[0].size(0)\n",
|
| 104 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" \n",
|
| 107 |
+
"# Calculate training accuracy after each epoch\n",
|
| 108 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 109 |
+
"print(\"===========================\")\n",
|
| 110 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 111 |
+
"print(\"===========================\")\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"import pickle\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"# Pickle the dictionary to a file\n",
|
| 116 |
+
"with open('models_8/test_7109.pkl', 'wb') as f:\n",
|
| 117 |
+
" pickle.dump(results, f)"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "code",
|
| 122 |
+
"execution_count": 2,
|
| 123 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 124 |
+
"metadata": {},
|
| 125 |
+
"outputs": [
|
| 126 |
+
{
|
| 127 |
+
"name": "stdout",
|
| 128 |
+
"output_type": "stream",
|
| 129 |
+
"text": [
|
| 130 |
+
"6\n",
|
| 131 |
+
"num params encoder 50840\n",
|
| 132 |
+
"num params 21496282\n"
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"name": "stderr",
|
| 137 |
+
"output_type": "stream",
|
| 138 |
+
"text": [
|
| 139 |
+
"100%|██████████████████��████████████████████████| 48/48 [00:54<00:00, 1.14s/it]"
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"name": "stdout",
|
| 144 |
+
"output_type": "stream",
|
| 145 |
+
"text": [
|
| 146 |
+
"===========================\n",
|
| 147 |
+
"accuracy: 99.17\n",
|
| 148 |
+
"===========================\n"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"name": "stderr",
|
| 153 |
+
"output_type": "stream",
|
| 154 |
+
"text": [
|
| 155 |
+
"\n"
|
| 156 |
+
]
|
| 157 |
+
}
|
| 158 |
+
],
|
| 159 |
+
"source": [
|
| 160 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 161 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 162 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 163 |
+
"from tqdm import tqdm\n",
|
| 164 |
+
"import torch\n",
|
| 165 |
+
"import numpy as np\n",
|
| 166 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 167 |
+
"import torch\n",
|
| 168 |
+
"import torch.nn as nn\n",
|
| 169 |
+
"import torch.optim as optim\n",
|
| 170 |
+
"from tqdm import tqdm \n",
|
| 171 |
+
"import torch.nn.functional as F\n",
|
| 172 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 173 |
+
"import pickle\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"torch.manual_seed(1)\n",
|
| 176 |
+
"# torch.manual_seed(42)\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 180 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 181 |
+
"print(num_gpus)\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 184 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"num_classes = 2\n",
|
| 187 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 190 |
+
"model = nn.DataParallel(model)\n",
|
| 191 |
+
"model = model.to(device)\n",
|
| 192 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 193 |
+
"print(\"num params \",params)\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"model_1 = 'models_8/model-44-99.445_42.pt'\n",
|
| 197 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 198 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 199 |
+
"model = model.eval()\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"# eval\n",
|
| 202 |
+
"val_loss = 0.0\n",
|
| 203 |
+
"correct_valid = 0\n",
|
| 204 |
+
"total = 0\n",
|
| 205 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 206 |
+
"model.eval()\n",
|
| 207 |
+
"with torch.no_grad():\n",
|
| 208 |
+
" for images, labels in tqdm(testloader):\n",
|
| 209 |
+
" inputs, labels = images.to(device), labels\n",
|
| 210 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 211 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 212 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 213 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 214 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 215 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 216 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 217 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 218 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 219 |
+
" total += labels[0].size(0)\n",
|
| 220 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 221 |
+
" \n",
|
| 222 |
+
"# Calculate training accuracy after each epoch\n",
|
| 223 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 224 |
+
"print(\"===========================\")\n",
|
| 225 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 226 |
+
"print(\"===========================\")\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"import pickle\n",
|
| 229 |
+
"\n",
|
| 230 |
+
"# Pickle the dictionary to a file\n",
|
| 231 |
+
"with open('models_8/test_42.pkl', 'wb') as f:\n",
|
| 232 |
+
" pickle.dump(results, f)"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "code",
|
| 237 |
+
"execution_count": 3,
|
| 238 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 239 |
+
"metadata": {},
|
| 240 |
+
"outputs": [
|
| 241 |
+
{
|
| 242 |
+
"name": "stdout",
|
| 243 |
+
"output_type": "stream",
|
| 244 |
+
"text": [
|
| 245 |
+
"6\n",
|
| 246 |
+
"num params encoder 50840\n",
|
| 247 |
+
"num params 21496282\n"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "stderr",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"100%|███████████████████████████████████████████| 48/48 [00:54<00:00, 1.14s/it]"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"name": "stdout",
|
| 259 |
+
"output_type": "stream",
|
| 260 |
+
"text": [
|
| 261 |
+
"===========================\n",
|
| 262 |
+
"accuracy: 99.035\n",
|
| 263 |
+
"===========================\n"
|
| 264 |
+
]
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"name": "stderr",
|
| 268 |
+
"output_type": "stream",
|
| 269 |
+
"text": [
|
| 270 |
+
"\n"
|
| 271 |
+
]
|
| 272 |
+
}
|
| 273 |
+
],
|
| 274 |
+
"source": [
|
| 275 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 276 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 277 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 278 |
+
"from tqdm import tqdm\n",
|
| 279 |
+
"import torch\n",
|
| 280 |
+
"import numpy as np\n",
|
| 281 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 282 |
+
"import torch\n",
|
| 283 |
+
"import torch.nn as nn\n",
|
| 284 |
+
"import torch.optim as optim\n",
|
| 285 |
+
"from tqdm import tqdm \n",
|
| 286 |
+
"import torch.nn.functional as F\n",
|
| 287 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 288 |
+
"import pickle\n",
|
| 289 |
+
"\n",
|
| 290 |
+
"torch.manual_seed(1)\n",
|
| 291 |
+
"# torch.manual_seed(42)\n",
|
| 292 |
+
"\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 295 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 296 |
+
"print(num_gpus)\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 299 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"num_classes = 2\n",
|
| 302 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 305 |
+
"model = nn.DataParallel(model)\n",
|
| 306 |
+
"model = model.to(device)\n",
|
| 307 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 308 |
+
"print(\"num params \",params)\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"model_1 = 'models_8/model-43-99.355_1.pt'\n",
|
| 312 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 313 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 314 |
+
"model = model.eval()\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"# eval\n",
|
| 317 |
+
"val_loss = 0.0\n",
|
| 318 |
+
"correct_valid = 0\n",
|
| 319 |
+
"total = 0\n",
|
| 320 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 321 |
+
"model.eval()\n",
|
| 322 |
+
"with torch.no_grad():\n",
|
| 323 |
+
" for images, labels in tqdm(testloader):\n",
|
| 324 |
+
" inputs, labels = images.to(device), labels\n",
|
| 325 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 326 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 327 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 328 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 329 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 330 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 331 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 332 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 333 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 334 |
+
" total += labels[0].size(0)\n",
|
| 335 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 336 |
+
" \n",
|
| 337 |
+
"# Calculate training accuracy after each epoch\n",
|
| 338 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 339 |
+
"print(\"===========================\")\n",
|
| 340 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 341 |
+
"print(\"===========================\")\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"import pickle\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"# Pickle the dictionary to a file\n",
|
| 346 |
+
"with open('models_8/test_1.pkl', 'wb') as f:\n",
|
| 347 |
+
" pickle.dump(results, f)"
|
| 348 |
+
]
|
| 349 |
+
}
|
| 350 |
+
],
|
| 351 |
+
"metadata": {
|
| 352 |
+
"kernelspec": {
|
| 353 |
+
"display_name": "Python 3 (ipykernel)",
|
| 354 |
+
"language": "python",
|
| 355 |
+
"name": "python3"
|
| 356 |
+
},
|
| 357 |
+
"language_info": {
|
| 358 |
+
"codemirror_mode": {
|
| 359 |
+
"name": "ipython",
|
| 360 |
+
"version": 3
|
| 361 |
+
},
|
| 362 |
+
"file_extension": ".py",
|
| 363 |
+
"mimetype": "text/x-python",
|
| 364 |
+
"name": "python",
|
| 365 |
+
"nbconvert_exporter": "python",
|
| 366 |
+
"pygments_lexer": "ipython3",
|
| 367 |
+
"version": "3.11.9"
|
| 368 |
+
}
|
| 369 |
+
},
|
| 370 |
+
"nbformat": 4,
|
| 371 |
+
"nbformat_minor": 5
|
| 372 |
+
}
|
models/.ipynb_checkpoints/eval_mask-8-extend-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"2\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 24 |
+
"100%|███████████████████████████████████████████| 48/48 [00:44<00:00, 1.09it/s]"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"name": "stdout",
|
| 29 |
+
"output_type": "stream",
|
| 30 |
+
"text": [
|
| 31 |
+
"===========================\n",
|
| 32 |
+
"accuracy: 98.94\n",
|
| 33 |
+
"===========================\n",
|
| 34 |
+
"False Positive Rate: 0.004\n",
|
| 35 |
+
"Precision: 0.996\n",
|
| 36 |
+
"Recall: 0.983\n",
|
| 37 |
+
"F1 Score: 0.989\n"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"name": "stderr",
|
| 42 |
+
"output_type": "stream",
|
| 43 |
+
"text": [
|
| 44 |
+
"\n"
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"source": [
|
| 49 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 50 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 51 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 52 |
+
"from tqdm import tqdm\n",
|
| 53 |
+
"import torch\n",
|
| 54 |
+
"import numpy as np\n",
|
| 55 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 56 |
+
"import torch\n",
|
| 57 |
+
"import torch.nn as nn\n",
|
| 58 |
+
"import torch.optim as optim\n",
|
| 59 |
+
"from tqdm import tqdm \n",
|
| 60 |
+
"import torch.nn.functional as F\n",
|
| 61 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 62 |
+
"import pickle\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"torch.manual_seed(1)\n",
|
| 65 |
+
"# torch.manual_seed(42)\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 69 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 70 |
+
"print(num_gpus)\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 73 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"num_classes = 2\n",
|
| 76 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 79 |
+
"model = nn.DataParallel(model)\n",
|
| 80 |
+
"model = model.to(device)\n",
|
| 81 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 82 |
+
"print(\"num params \",params)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"model_1 = 'models_8/model-25-99.31_7109.pt'\n",
|
| 85 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 86 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 87 |
+
"model = model.eval()\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"# eval\n",
|
| 90 |
+
"val_loss = 0.0\n",
|
| 91 |
+
"correct_valid = 0\n",
|
| 92 |
+
"total = 0\n",
|
| 93 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 94 |
+
"model.eval()\n",
|
| 95 |
+
"with torch.no_grad():\n",
|
| 96 |
+
" for images, labels in tqdm(testloader):\n",
|
| 97 |
+
" inputs, labels = images.to(device), labels\n",
|
| 98 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 99 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 100 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 101 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 102 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 103 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 104 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 105 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 106 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 107 |
+
" total += labels[0].size(0)\n",
|
| 108 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 109 |
+
" \n",
|
| 110 |
+
" \n",
|
| 111 |
+
"# Calculate training accuracy after each epoch\n",
|
| 112 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 113 |
+
"print(\"===========================\")\n",
|
| 114 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 115 |
+
"print(\"===========================\")\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"import pickle\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"# Pickle the dictionary to a file\n",
|
| 120 |
+
"with open('models_8/test_7109.pkl', 'wb') as f:\n",
|
| 121 |
+
" pickle.dump(results, f)\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 124 |
+
"from sklearn.metrics import confusion_matrix\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"# Example binary labels\n",
|
| 127 |
+
"true = results['true'] # ground truth\n",
|
| 128 |
+
"pred = results['pred'] # predicted\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"# Compute metrics\n",
|
| 131 |
+
"precision = precision_score(true, pred)\n",
|
| 132 |
+
"recall = recall_score(true, pred)\n",
|
| 133 |
+
"f1 = f1_score(true, pred)\n",
|
| 134 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 135 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"# Compute FPR\n",
|
| 138 |
+
"fpr = fp / (fp + tn)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 143 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 144 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "code",
|
| 149 |
+
"execution_count": 3,
|
| 150 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"outputs": [
|
| 153 |
+
{
|
| 154 |
+
"name": "stdout",
|
| 155 |
+
"output_type": "stream",
|
| 156 |
+
"text": [
|
| 157 |
+
"2\n",
|
| 158 |
+
"num params encoder 50840\n",
|
| 159 |
+
"num params 21496282\n"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"name": "stderr",
|
| 164 |
+
"output_type": "stream",
|
| 165 |
+
"text": [
|
| 166 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 167 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 168 |
+
"100%|███████████████████████████████████████████| 48/48 [00:41<00:00, 1.15it/s]"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"name": "stdout",
|
| 173 |
+
"output_type": "stream",
|
| 174 |
+
"text": [
|
| 175 |
+
"===========================\n",
|
| 176 |
+
"accuracy: 99.17\n",
|
| 177 |
+
"===========================\n",
|
| 178 |
+
"False Positive Rate: 0.004\n",
|
| 179 |
+
"Precision: 0.995\n",
|
| 180 |
+
"Recall: 0.988\n",
|
| 181 |
+
"F1 Score: 0.992\n"
|
| 182 |
+
]
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"name": "stderr",
|
| 186 |
+
"output_type": "stream",
|
| 187 |
+
"text": [
|
| 188 |
+
"\n"
|
| 189 |
+
]
|
| 190 |
+
}
|
| 191 |
+
],
|
| 192 |
+
"source": [
|
| 193 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 194 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 195 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 196 |
+
"from tqdm import tqdm\n",
|
| 197 |
+
"import torch\n",
|
| 198 |
+
"import numpy as np\n",
|
| 199 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 200 |
+
"import torch\n",
|
| 201 |
+
"import torch.nn as nn\n",
|
| 202 |
+
"import torch.optim as optim\n",
|
| 203 |
+
"from tqdm import tqdm \n",
|
| 204 |
+
"import torch.nn.functional as F\n",
|
| 205 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 206 |
+
"import pickle\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"torch.manual_seed(1)\n",
|
| 209 |
+
"# torch.manual_seed(42)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 213 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 214 |
+
"print(num_gpus)\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 217 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"num_classes = 2\n",
|
| 220 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 223 |
+
"model = nn.DataParallel(model)\n",
|
| 224 |
+
"model = model.to(device)\n",
|
| 225 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 226 |
+
"print(\"num params \",params)\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"model_1 = 'models_8/model-44-99.445_42.pt'\n",
|
| 230 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 231 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 232 |
+
"model = model.eval()\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"# eval\n",
|
| 235 |
+
"val_loss = 0.0\n",
|
| 236 |
+
"correct_valid = 0\n",
|
| 237 |
+
"total = 0\n",
|
| 238 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 239 |
+
"model.eval()\n",
|
| 240 |
+
"with torch.no_grad():\n",
|
| 241 |
+
" for images, labels in tqdm(testloader):\n",
|
| 242 |
+
" inputs, labels = images.to(device), labels\n",
|
| 243 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 244 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 245 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 246 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 247 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 248 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 249 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 250 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 251 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 252 |
+
" total += labels[0].size(0)\n",
|
| 253 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 254 |
+
" \n",
|
| 255 |
+
"# Calculate training accuracy after each epoch\n",
|
| 256 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 257 |
+
"print(\"===========================\")\n",
|
| 258 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 259 |
+
"print(\"===========================\")\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"import pickle\n",
|
| 262 |
+
"\n",
|
| 263 |
+
"# Pickle the dictionary to a file\n",
|
| 264 |
+
"with open('models_8/test_42.pkl', 'wb') as f:\n",
|
| 265 |
+
" pickle.dump(results, f)\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"# Example binary labels\n",
|
| 270 |
+
"true = results['true'] # ground truth\n",
|
| 271 |
+
"pred = results['pred'] # predicted\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"# Compute metrics\n",
|
| 274 |
+
"precision = precision_score(true, pred)\n",
|
| 275 |
+
"recall = recall_score(true, pred)\n",
|
| 276 |
+
"f1 = f1_score(true, pred)\n",
|
| 277 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 278 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"# Compute FPR\n",
|
| 281 |
+
"fpr = fp / (fp + tn)\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 286 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 287 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 288 |
+
]
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"cell_type": "code",
|
| 292 |
+
"execution_count": 4,
|
| 293 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 294 |
+
"metadata": {},
|
| 295 |
+
"outputs": [
|
| 296 |
+
{
|
| 297 |
+
"name": "stdout",
|
| 298 |
+
"output_type": "stream",
|
| 299 |
+
"text": [
|
| 300 |
+
"2\n",
|
| 301 |
+
"num params encoder 50840\n",
|
| 302 |
+
"num params 21496282\n"
|
| 303 |
+
]
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"name": "stderr",
|
| 307 |
+
"output_type": "stream",
|
| 308 |
+
"text": [
|
| 309 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 310 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 311 |
+
"100%|███████████████████████████████████████████| 48/48 [00:53<00:00, 1.12s/it]"
|
| 312 |
+
]
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"name": "stdout",
|
| 316 |
+
"output_type": "stream",
|
| 317 |
+
"text": [
|
| 318 |
+
"===========================\n",
|
| 319 |
+
"accuracy: 99.035\n",
|
| 320 |
+
"===========================\n",
|
| 321 |
+
"False Positive Rate: 0.010\n",
|
| 322 |
+
"Precision: 0.990\n",
|
| 323 |
+
"Recall: 0.990\n",
|
| 324 |
+
"F1 Score: 0.990\n"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"name": "stderr",
|
| 329 |
+
"output_type": "stream",
|
| 330 |
+
"text": [
|
| 331 |
+
"\n"
|
| 332 |
+
]
|
| 333 |
+
}
|
| 334 |
+
],
|
| 335 |
+
"source": [
|
| 336 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 337 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 338 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 339 |
+
"from tqdm import tqdm\n",
|
| 340 |
+
"import torch\n",
|
| 341 |
+
"import numpy as np\n",
|
| 342 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 343 |
+
"import torch\n",
|
| 344 |
+
"import torch.nn as nn\n",
|
| 345 |
+
"import torch.optim as optim\n",
|
| 346 |
+
"from tqdm import tqdm \n",
|
| 347 |
+
"import torch.nn.functional as F\n",
|
| 348 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 349 |
+
"import pickle\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"torch.manual_seed(1)\n",
|
| 352 |
+
"# torch.manual_seed(42)\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 356 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 357 |
+
"print(num_gpus)\n",
|
| 358 |
+
"\n",
|
| 359 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 360 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"num_classes = 2\n",
|
| 363 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 366 |
+
"model = nn.DataParallel(model)\n",
|
| 367 |
+
"model = model.to(device)\n",
|
| 368 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 369 |
+
"print(\"num params \",params)\n",
|
| 370 |
+
"\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"model_1 = 'models_8/model-43-99.355_1.pt'\n",
|
| 373 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 374 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 375 |
+
"model = model.eval()\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"# eval\n",
|
| 378 |
+
"val_loss = 0.0\n",
|
| 379 |
+
"correct_valid = 0\n",
|
| 380 |
+
"total = 0\n",
|
| 381 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 382 |
+
"model.eval()\n",
|
| 383 |
+
"with torch.no_grad():\n",
|
| 384 |
+
" for images, labels in tqdm(testloader):\n",
|
| 385 |
+
" inputs, labels = images.to(device), labels\n",
|
| 386 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 387 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 388 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 389 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 390 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 391 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 392 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 393 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 394 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 395 |
+
" total += labels[0].size(0)\n",
|
| 396 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 397 |
+
" \n",
|
| 398 |
+
"# Calculate training accuracy after each epoch\n",
|
| 399 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 400 |
+
"print(\"===========================\")\n",
|
| 401 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 402 |
+
"print(\"===========================\")\n",
|
| 403 |
+
"\n",
|
| 404 |
+
"import pickle\n",
|
| 405 |
+
"\n",
|
| 406 |
+
"# Pickle the dictionary to a file\n",
|
| 407 |
+
"with open('models_8/test_1.pkl', 'wb') as f:\n",
|
| 408 |
+
" pickle.dump(results, f)\n",
|
| 409 |
+
"\n",
|
| 410 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"# Example binary labels\n",
|
| 413 |
+
"true = results['true'] # ground truth\n",
|
| 414 |
+
"pred = results['pred'] # predicted\n",
|
| 415 |
+
"\n",
|
| 416 |
+
"# Compute metrics\n",
|
| 417 |
+
"precision = precision_score(true, pred)\n",
|
| 418 |
+
"recall = recall_score(true, pred)\n",
|
| 419 |
+
"f1 = f1_score(true, pred)\n",
|
| 420 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 421 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 422 |
+
"\n",
|
| 423 |
+
"# Compute FPR\n",
|
| 424 |
+
"fpr = fp / (fp + tn)\n",
|
| 425 |
+
"\n",
|
| 426 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 427 |
+
"\n",
|
| 428 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 429 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 430 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 431 |
+
]
|
| 432 |
+
},
|
| 433 |
+
{
|
| 434 |
+
"cell_type": "code",
|
| 435 |
+
"execution_count": 7,
|
| 436 |
+
"id": "8444ced5-686c-4c6e-bb02-8c4a45ff8af9",
|
| 437 |
+
"metadata": {},
|
| 438 |
+
"outputs": [
|
| 439 |
+
{
|
| 440 |
+
"name": "stdout",
|
| 441 |
+
"output_type": "stream",
|
| 442 |
+
"text": [
|
| 443 |
+
"99.04833333333333 0.09436925111261553\n",
|
| 444 |
+
"0.9936666666666666 0.002624669291337273\n",
|
| 445 |
+
"0.9903333333333334 0.0012472191289246482\n",
|
| 446 |
+
"0.006000000000000001 0.00282842712474619\n"
|
| 447 |
+
]
|
| 448 |
+
}
|
| 449 |
+
],
|
| 450 |
+
"source": [
|
| 451 |
+
"# acc\n",
|
| 452 |
+
"print(np.mean([98.94,99.17, 99.035 ]), np.std([98.94,99.17, 99.035 ]))\n",
|
| 453 |
+
"# recall\n",
|
| 454 |
+
"print(np.mean([0.996,0.988, 0.990]), np.std([0.996,0.995, 0.990]))\n",
|
| 455 |
+
"# f1\n",
|
| 456 |
+
"print(np.mean([0.989,0.992,0.990 ]),np.std([0.989,0.992,0.990 ]))\n",
|
| 457 |
+
"# fp\n",
|
| 458 |
+
"print(np.mean([0.004,0.004,0.010]),np.std([0.004,0.004,0.010]))\n"
|
| 459 |
+
]
|
| 460 |
+
}
|
| 461 |
+
],
|
| 462 |
+
"metadata": {
|
| 463 |
+
"kernelspec": {
|
| 464 |
+
"display_name": "Python 3 (ipykernel)",
|
| 465 |
+
"language": "python",
|
| 466 |
+
"name": "python3"
|
| 467 |
+
},
|
| 468 |
+
"language_info": {
|
| 469 |
+
"codemirror_mode": {
|
| 470 |
+
"name": "ipython",
|
| 471 |
+
"version": 3
|
| 472 |
+
},
|
| 473 |
+
"file_extension": ".py",
|
| 474 |
+
"mimetype": "text/x-python",
|
| 475 |
+
"name": "python",
|
| 476 |
+
"nbconvert_exporter": "python",
|
| 477 |
+
"pygments_lexer": "ipython3",
|
| 478 |
+
"version": "3.11.9"
|
| 479 |
+
}
|
| 480 |
+
},
|
| 481 |
+
"nbformat": 4,
|
| 482 |
+
"nbformat_minor": 5
|
| 483 |
+
}
|
models/.ipynb_checkpoints/eval_mask-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"6\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 24 |
+
" 8%|███▋ | 4/48 [02:22<25:11, 34.35s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7efcbc3f67d0>>\n",
|
| 25 |
+
"Traceback (most recent call last):\n",
|
| 26 |
+
" File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 775, in _clean_thread_parent_frames\n",
|
| 27 |
+
" def _clean_thread_parent_frames(\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"KeyboardInterrupt: \n",
|
| 30 |
+
" 10%|████▌ | 5/48 [02:53<24:56, 34.79s/it]\n"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"ename": "RuntimeError",
|
| 35 |
+
"evalue": "DataLoader worker (pid(s) 4158742, 4158790, 4158838, 4158886, 4158934, 4158982, 4159030, 4159078, 4159126, 4159174, 4159222, 4159270, 4159318, 4159366, 4159414, 4159462, 4159510, 4159558, 4159606, 4159654, 4159702, 4159750, 4159798, 4159846, 4159894, 4159942, 4159990, 4160038, 4160086, 4160134, 4160182, 4160230) exited unexpectedly",
|
| 36 |
+
"output_type": "error",
|
| 37 |
+
"traceback": [
|
| 38 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 39 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
| 40 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1131\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1131\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_queue\u001b[38;5;241m.\u001b[39mget(timeout\u001b[38;5;241m=\u001b[39mtimeout)\n\u001b[1;32m 1132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
|
| 41 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/queues.py:122\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# unserialize the data after having released the lock\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _ForkingPickler\u001b[38;5;241m.\u001b[39mloads(res)\n",
|
| 42 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/multiprocessing/reductions.py:496\u001b[0m, in \u001b[0;36mrebuild_storage_fd\u001b[0;34m(cls, df, size)\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrebuild_storage_fd\u001b[39m(\u001b[38;5;28mcls\u001b[39m, df, size):\n\u001b[0;32m--> 496\u001b[0m fd \u001b[38;5;241m=\u001b[39m df\u001b[38;5;241m.\u001b[39mdetach()\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
|
| 43 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/resource_sharer.py:57\u001b[0m, in \u001b[0;36mDupFd.detach\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m'''Get the fd. This should only be called once.'''\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _resource_sharer\u001b[38;5;241m.\u001b[39mget_connection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_id) \u001b[38;5;28;01mas\u001b[39;00m conn:\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m reduction\u001b[38;5;241m.\u001b[39mrecv_handle(conn)\n",
|
| 44 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/resource_sharer.py:86\u001b[0m, in \u001b[0;36m_ResourceSharer.get_connection\u001b[0;34m(ident)\u001b[0m\n\u001b[1;32m 85\u001b[0m address, key \u001b[38;5;241m=\u001b[39m ident\n\u001b[0;32m---> 86\u001b[0m c \u001b[38;5;241m=\u001b[39m Client(address, authkey\u001b[38;5;241m=\u001b[39mprocess\u001b[38;5;241m.\u001b[39mcurrent_process()\u001b[38;5;241m.\u001b[39mauthkey)\n\u001b[1;32m 87\u001b[0m c\u001b[38;5;241m.\u001b[39msend((key, os\u001b[38;5;241m.\u001b[39mgetpid()))\n",
|
| 45 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/connection.py:519\u001b[0m, in \u001b[0;36mClient\u001b[0;34m(address, family, authkey)\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 519\u001b[0m c \u001b[38;5;241m=\u001b[39m SocketClient(address)\n\u001b[1;32m 521\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m authkey \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(authkey, \u001b[38;5;28mbytes\u001b[39m):\n",
|
| 46 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/connection.py:647\u001b[0m, in \u001b[0;36mSocketClient\u001b[0;34m(address)\u001b[0m\n\u001b[1;32m 646\u001b[0m s\u001b[38;5;241m.\u001b[39msetblocking(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 647\u001b[0m s\u001b[38;5;241m.\u001b[39mconnect(address)\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Connection(s\u001b[38;5;241m.\u001b[39mdetach())\n",
|
| 47 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory",
|
| 48 |
+
"\nThe above exception was the direct cause of the following exception:\n",
|
| 49 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 50 |
+
"Cell \u001b[0;32mIn[1], line 48\u001b[0m\n\u001b[1;32m 46\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(testloader):\n\u001b[1;32m 49\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\n\u001b[1;32m 50\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(inputs, return_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
| 51 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
|
| 52 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_data()\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
|
| 53 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1327\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1327\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_data()\n\u001b[1;32m 1328\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1329\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1330\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
|
| 54 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1293\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m 1290\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1291\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1292\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1293\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_try_get_data()\n\u001b[1;32m 1294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
|
| 55 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1144\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1142\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1143\u001b[0m pids_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(w\u001b[38;5;241m.\u001b[39mpid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[0;32m-> 1144\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) exited unexpectedly\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 1145\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue\u001b[38;5;241m.\u001b[39mEmpty):\n\u001b[1;32m 1146\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
|
| 56 |
+
"\u001b[0;31mRuntimeError\u001b[0m: DataLoader worker (pid(s) 4158742, 4158790, 4158838, 4158886, 4158934, 4158982, 4159030, 4159078, 4159126, 4159174, 4159222, 4159270, 4159318, 4159366, 4159414, 4159462, 4159510, 4159558, 4159606, 4159654, 4159702, 4159750, 4159798, 4159846, 4159894, 4159942, 4159990, 4160038, 4160086, 4160134, 4160182, 4160230) exited unexpectedly"
|
| 57 |
+
]
|
| 58 |
+
}
|
| 59 |
+
],
|
| 60 |
+
"source": [
|
| 61 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 62 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 63 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 64 |
+
"from tqdm import tqdm\n",
|
| 65 |
+
"import torch\n",
|
| 66 |
+
"import numpy as np\n",
|
| 67 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 68 |
+
"import torch\n",
|
| 69 |
+
"import torch.nn as nn\n",
|
| 70 |
+
"import torch.optim as optim\n",
|
| 71 |
+
"from tqdm import tqdm \n",
|
| 72 |
+
"import torch.nn.functional as F\n",
|
| 73 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 74 |
+
"import pickle\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"torch.manual_seed(1)\n",
|
| 77 |
+
"# torch.manual_seed(42)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 81 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 82 |
+
"print(num_gpus)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 85 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"num_classes = 2\n",
|
| 88 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 91 |
+
"model = nn.DataParallel(model)\n",
|
| 92 |
+
"model = model.to(device)\n",
|
| 93 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 94 |
+
"print(\"num params \",params)\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"model_1 = 'models_mask/model-43-99.235_42.pt'\n",
|
| 97 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 98 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 99 |
+
"model = model.eval()\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"# eval\n",
|
| 102 |
+
"val_loss = 0.0\n",
|
| 103 |
+
"correct_valid = 0\n",
|
| 104 |
+
"total = 0\n",
|
| 105 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 106 |
+
"model.eval()\n",
|
| 107 |
+
"with torch.no_grad():\n",
|
| 108 |
+
" for images, labels in tqdm(testloader):\n",
|
| 109 |
+
" inputs, labels = images.to(device), labels\n",
|
| 110 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 111 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 112 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 113 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 114 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 115 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 116 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 117 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 118 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 119 |
+
" total += labels[0].size(0)\n",
|
| 120 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 121 |
+
" \n",
|
| 122 |
+
"# Calculate training accuracy after each epoch\n",
|
| 123 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 124 |
+
"print(\"===========================\")\n",
|
| 125 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 126 |
+
"print(\"===========================\")\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"import pickle\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"# Pickle the dictionary to a file\n",
|
| 131 |
+
"with open('models_mask/test_42.pkl', 'wb') as f:\n",
|
| 132 |
+
" pickle.dump(results, f)"
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": null,
|
| 138 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [],
|
| 141 |
+
"source": [
|
| 142 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 143 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 144 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 145 |
+
"from tqdm import tqdm\n",
|
| 146 |
+
"import torch\n",
|
| 147 |
+
"import numpy as np\n",
|
| 148 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 149 |
+
"import torch\n",
|
| 150 |
+
"import torch.nn as nn\n",
|
| 151 |
+
"import torch.optim as optim\n",
|
| 152 |
+
"from tqdm import tqdm \n",
|
| 153 |
+
"import torch.nn.functional as F\n",
|
| 154 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 155 |
+
"import pickle\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"torch.manual_seed(1)\n",
|
| 158 |
+
"# torch.manual_seed(42)\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 162 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 163 |
+
"print(num_gpus)\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 166 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 167 |
+
"\n",
|
| 168 |
+
"num_classes = 2\n",
|
| 169 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 172 |
+
"model = nn.DataParallel(model)\n",
|
| 173 |
+
"model = model.to(device)\n",
|
| 174 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 175 |
+
"print(\"num params \",params)\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
|
| 179 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 180 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 181 |
+
"model = model.eval()\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"# eval\n",
|
| 184 |
+
"val_loss = 0.0\n",
|
| 185 |
+
"correct_valid = 0\n",
|
| 186 |
+
"total = 0\n",
|
| 187 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 188 |
+
"model.eval()\n",
|
| 189 |
+
"with torch.no_grad():\n",
|
| 190 |
+
" for images, labels in tqdm(testloader):\n",
|
| 191 |
+
" inputs, labels = images.to(device), labels\n",
|
| 192 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 193 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 194 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 195 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 196 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 197 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 198 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 199 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 200 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 201 |
+
" total += labels[0].size(0)\n",
|
| 202 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 203 |
+
" \n",
|
| 204 |
+
" \n",
|
| 205 |
+
"# Calculate training accuracy after each epoch\n",
|
| 206 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 207 |
+
"print(\"===========================\")\n",
|
| 208 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 209 |
+
"print(\"===========================\")\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"import pickle\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"# Pickle the dictionary to a file\n",
|
| 214 |
+
"with open('models_mask/test_1.pkl', 'wb') as f:\n",
|
| 215 |
+
" pickle.dump(results, f)"
|
| 216 |
+
]
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"cell_type": "code",
|
| 220 |
+
"execution_count": null,
|
| 221 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"outputs": [],
|
| 224 |
+
"source": [
|
| 225 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 226 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 227 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 228 |
+
"from tqdm import tqdm\n",
|
| 229 |
+
"import torch\n",
|
| 230 |
+
"import numpy as np\n",
|
| 231 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 232 |
+
"import torch\n",
|
| 233 |
+
"import torch.nn as nn\n",
|
| 234 |
+
"import torch.optim as optim\n",
|
| 235 |
+
"from tqdm import tqdm \n",
|
| 236 |
+
"import torch.nn.functional as F\n",
|
| 237 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 238 |
+
"import pickle\n",
|
| 239 |
+
"\n",
|
| 240 |
+
"torch.manual_seed(1)\n",
|
| 241 |
+
"# torch.manual_seed(42)\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 245 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 246 |
+
"print(num_gpus)\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 249 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 250 |
+
"\n",
|
| 251 |
+
"num_classes = 2\n",
|
| 252 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 255 |
+
"model = nn.DataParallel(model)\n",
|
| 256 |
+
"model = model.to(device)\n",
|
| 257 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 258 |
+
"print(\"num params \",params)\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
|
| 262 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 263 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 264 |
+
"model = model.eval()\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"# eval\n",
|
| 267 |
+
"val_loss = 0.0\n",
|
| 268 |
+
"correct_valid = 0\n",
|
| 269 |
+
"total = 0\n",
|
| 270 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 271 |
+
"model.eval()\n",
|
| 272 |
+
"with torch.no_grad():\n",
|
| 273 |
+
" for images, labels in tqdm(testloader):\n",
|
| 274 |
+
" inputs, labels = images.to(device), labels\n",
|
| 275 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 276 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 277 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 278 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 279 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 280 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 281 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 282 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 283 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 284 |
+
" total += labels[0].size(0)\n",
|
| 285 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 286 |
+
" \n",
|
| 287 |
+
" \n",
|
| 288 |
+
"# Calculate training accuracy after each epoch\n",
|
| 289 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 290 |
+
"print(\"===========================\")\n",
|
| 291 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 292 |
+
"print(\"===========================\")\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"import pickle\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"# Pickle the dictionary to a file\n",
|
| 297 |
+
"with open('models_mask/test_7109.pkl', 'wb') as f:\n",
|
| 298 |
+
" pickle.dump(results, f)"
|
| 299 |
+
]
|
| 300 |
+
}
|
| 301 |
+
],
|
| 302 |
+
"metadata": {
|
| 303 |
+
"kernelspec": {
|
| 304 |
+
"display_name": "Python 3 (ipykernel)",
|
| 305 |
+
"language": "python",
|
| 306 |
+
"name": "python3"
|
| 307 |
+
},
|
| 308 |
+
"language_info": {
|
| 309 |
+
"codemirror_mode": {
|
| 310 |
+
"name": "ipython",
|
| 311 |
+
"version": 3
|
| 312 |
+
},
|
| 313 |
+
"file_extension": ".py",
|
| 314 |
+
"mimetype": "text/x-python",
|
| 315 |
+
"name": "python",
|
| 316 |
+
"nbconvert_exporter": "python",
|
| 317 |
+
"pygments_lexer": "ipython3",
|
| 318 |
+
"version": "3.11.9"
|
| 319 |
+
}
|
| 320 |
+
},
|
| 321 |
+
"nbformat": 4,
|
| 322 |
+
"nbformat_minor": 5
|
| 323 |
+
}
|
models/.ipynb_checkpoints/eval_mask-extend-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 9,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"2\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 24 |
+
"100%|███████████████████████████████████████████| 48/48 [00:45<00:00, 1.07it/s]"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"name": "stdout",
|
| 29 |
+
"output_type": "stream",
|
| 30 |
+
"text": [
|
| 31 |
+
"===========================\n",
|
| 32 |
+
"accuracy: 99.125\n",
|
| 33 |
+
"===========================\n",
|
| 34 |
+
"Precision: 0.992\n",
|
| 35 |
+
"Recall: 0.991\n",
|
| 36 |
+
"F1 Score: 0.991\n"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"name": "stderr",
|
| 41 |
+
"output_type": "stream",
|
| 42 |
+
"text": [
|
| 43 |
+
"\n"
|
| 44 |
+
]
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"source": [
|
| 48 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 49 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 50 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 51 |
+
"from tqdm import tqdm\n",
|
| 52 |
+
"import torch\n",
|
| 53 |
+
"import numpy as np\n",
|
| 54 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 55 |
+
"import torch\n",
|
| 56 |
+
"import torch.nn as nn\n",
|
| 57 |
+
"import torch.optim as optim\n",
|
| 58 |
+
"from tqdm import tqdm \n",
|
| 59 |
+
"import torch.nn.functional as F\n",
|
| 60 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 61 |
+
"import pickle\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"torch.manual_seed(1)\n",
|
| 64 |
+
"# torch.manual_seed(42)\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 68 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 69 |
+
"print(num_gpus)\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 72 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"num_classes = 2\n",
|
| 75 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 78 |
+
"model = nn.DataParallel(model)\n",
|
| 79 |
+
"model = model.to(device)\n",
|
| 80 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 81 |
+
"print(\"num params \",params)\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"model_1 = 'models_mask/model-43-99.235_42.pt'\n",
|
| 84 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 85 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 86 |
+
"model = model.eval()\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"# eval\n",
|
| 89 |
+
"val_loss = 0.0\n",
|
| 90 |
+
"correct_valid = 0\n",
|
| 91 |
+
"total = 0\n",
|
| 92 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 93 |
+
"model.eval()\n",
|
| 94 |
+
"with torch.no_grad():\n",
|
| 95 |
+
" for images, labels in tqdm(testloader):\n",
|
| 96 |
+
" inputs, labels = images.to(device), labels\n",
|
| 97 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 98 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 99 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 100 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 101 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 102 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 103 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 104 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 105 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 106 |
+
" total += labels[0].size(0)\n",
|
| 107 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 108 |
+
" \n",
|
| 109 |
+
"# Calculate training accuracy after each epoch\n",
|
| 110 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 111 |
+
"print(\"===========================\")\n",
|
| 112 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 113 |
+
"print(\"===========================\")\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"import pickle\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"# Pickle the dictionary to a file\n",
|
| 118 |
+
"with open('models_mask/test_42.pkl', 'wb') as f:\n",
|
| 119 |
+
" pickle.dump(results, f)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"# Example binary labels\n",
|
| 124 |
+
"true = results['true'] # ground truth\n",
|
| 125 |
+
"pred = results['pred'] # predicted\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"# Compute metrics\n",
|
| 128 |
+
"precision = precision_score(true, pred)\n",
|
| 129 |
+
"recall = recall_score(true, pred)\n",
|
| 130 |
+
"f1 = f1_score(true, pred)\n",
|
| 131 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 132 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"# Compute FPR\n",
|
| 135 |
+
"fpr = fp / (fp + tn)\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 140 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 141 |
+
"print(f\"F1 Score: {f1:.3f}\")\n"
|
| 142 |
+
]
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"cell_type": "code",
|
| 146 |
+
"execution_count": 10,
|
| 147 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"outputs": [
|
| 150 |
+
{
|
| 151 |
+
"name": "stdout",
|
| 152 |
+
"output_type": "stream",
|
| 153 |
+
"text": [
|
| 154 |
+
"2\n",
|
| 155 |
+
"num params encoder 50840\n",
|
| 156 |
+
"num params 21496282\n"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"name": "stderr",
|
| 161 |
+
"output_type": "stream",
|
| 162 |
+
"text": [
|
| 163 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 164 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 165 |
+
"100%|███████████████████████████████████████████| 48/48 [00:43<00:00, 1.11it/s]"
|
| 166 |
+
]
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"name": "stdout",
|
| 170 |
+
"output_type": "stream",
|
| 171 |
+
"text": [
|
| 172 |
+
"===========================\n",
|
| 173 |
+
"accuracy: 98.77\n",
|
| 174 |
+
"===========================\n",
|
| 175 |
+
"Precision: 0.982\n",
|
| 176 |
+
"Recall: 0.993\n",
|
| 177 |
+
"F1 Score: 0.988\n"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"name": "stderr",
|
| 182 |
+
"output_type": "stream",
|
| 183 |
+
"text": [
|
| 184 |
+
"\n"
|
| 185 |
+
]
|
| 186 |
+
}
|
| 187 |
+
],
|
| 188 |
+
"source": [
|
| 189 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 190 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 191 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 192 |
+
"from tqdm import tqdm\n",
|
| 193 |
+
"import torch\n",
|
| 194 |
+
"import numpy as np\n",
|
| 195 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 196 |
+
"import torch\n",
|
| 197 |
+
"import torch.nn as nn\n",
|
| 198 |
+
"import torch.optim as optim\n",
|
| 199 |
+
"from tqdm import tqdm \n",
|
| 200 |
+
"import torch.nn.functional as F\n",
|
| 201 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 202 |
+
"import pickle\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"torch.manual_seed(1)\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 208 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 209 |
+
"print(num_gpus)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 212 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"num_classes = 2\n",
|
| 215 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 218 |
+
"model = nn.DataParallel(model)\n",
|
| 219 |
+
"model = model.to(device)\n",
|
| 220 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 221 |
+
"print(\"num params \",params)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
|
| 225 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 226 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 227 |
+
"model = model.eval()\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"# eval\n",
|
| 230 |
+
"val_loss = 0.0\n",
|
| 231 |
+
"correct_valid = 0\n",
|
| 232 |
+
"total = 0\n",
|
| 233 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 234 |
+
"model.eval()\n",
|
| 235 |
+
"with torch.no_grad():\n",
|
| 236 |
+
" for images, labels in tqdm(testloader):\n",
|
| 237 |
+
" inputs, labels = images.to(device), labels\n",
|
| 238 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 239 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 240 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 241 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 242 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 243 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 244 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 245 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 246 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 247 |
+
" total += labels[0].size(0)\n",
|
| 248 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 249 |
+
" \n",
|
| 250 |
+
" \n",
|
| 251 |
+
"# Calculate training accuracy after each epoch\n",
|
| 252 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 253 |
+
"print(\"===========================\")\n",
|
| 254 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 255 |
+
"print(\"===========================\")\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"import pickle\n",
|
| 258 |
+
"\n",
|
| 259 |
+
"# Pickle the dictionary to a file\n",
|
| 260 |
+
"with open('models_mask/test_1.pkl', 'wb') as f:\n",
|
| 261 |
+
" pickle.dump(results, f)\n",
|
| 262 |
+
"\n",
|
| 263 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"# Example binary labels\n",
|
| 266 |
+
"true = results['true'] # ground truth\n",
|
| 267 |
+
"pred = results['pred'] # predicted\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"# Compute metrics\n",
|
| 270 |
+
"precision = precision_score(true, pred)\n",
|
| 271 |
+
"recall = recall_score(true, pred)\n",
|
| 272 |
+
"f1 = f1_score(true, pred)\n",
|
| 273 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 274 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"# Compute FPR\n",
|
| 277 |
+
"fpr = fp / (fp + tn)\n",
|
| 278 |
+
"\n",
|
| 279 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 280 |
+
"\n",
|
| 281 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 282 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 283 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"cell_type": "code",
|
| 288 |
+
"execution_count": 11,
|
| 289 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 290 |
+
"metadata": {},
|
| 291 |
+
"outputs": [
|
| 292 |
+
{
|
| 293 |
+
"name": "stdout",
|
| 294 |
+
"output_type": "stream",
|
| 295 |
+
"text": [
|
| 296 |
+
"2\n",
|
| 297 |
+
"num params encoder 50840\n",
|
| 298 |
+
"num params 21496282\n"
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"name": "stderr",
|
| 303 |
+
"output_type": "stream",
|
| 304 |
+
"text": [
|
| 305 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 306 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 307 |
+
"100%|███████████████████████████████████████████| 48/48 [00:43<00:00, 1.11it/s]"
|
| 308 |
+
]
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"name": "stdout",
|
| 312 |
+
"output_type": "stream",
|
| 313 |
+
"text": [
|
| 314 |
+
"===========================\n",
|
| 315 |
+
"accuracy: 99.03\n",
|
| 316 |
+
"===========================\n",
|
| 317 |
+
"Precision: 0.990\n",
|
| 318 |
+
"Recall: 0.990\n",
|
| 319 |
+
"F1 Score: 0.990\n"
|
| 320 |
+
]
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"name": "stderr",
|
| 324 |
+
"output_type": "stream",
|
| 325 |
+
"text": [
|
| 326 |
+
"\n"
|
| 327 |
+
]
|
| 328 |
+
}
|
| 329 |
+
],
|
| 330 |
+
"source": [
|
| 331 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 332 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 333 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 334 |
+
"from tqdm import tqdm\n",
|
| 335 |
+
"import torch\n",
|
| 336 |
+
"import numpy as np\n",
|
| 337 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 338 |
+
"import torch\n",
|
| 339 |
+
"import torch.nn as nn\n",
|
| 340 |
+
"import torch.optim as optim\n",
|
| 341 |
+
"from tqdm import tqdm \n",
|
| 342 |
+
"import torch.nn.functional as F\n",
|
| 343 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 344 |
+
"import pickle\n",
|
| 345 |
+
"\n",
|
| 346 |
+
"torch.manual_seed(1)\n",
|
| 347 |
+
"# torch.manual_seed(42)\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 351 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 352 |
+
"print(num_gpus)\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 355 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 356 |
+
"\n",
|
| 357 |
+
"num_classes = 2\n",
|
| 358 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 361 |
+
"model = nn.DataParallel(model)\n",
|
| 362 |
+
"model = model.to(device)\n",
|
| 363 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 364 |
+
"print(\"num params \",params)\n",
|
| 365 |
+
"\n",
|
| 366 |
+
"\n",
|
| 367 |
+
"model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
|
| 368 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 369 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 370 |
+
"model = model.eval()\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"# eval\n",
|
| 373 |
+
"val_loss = 0.0\n",
|
| 374 |
+
"correct_valid = 0\n",
|
| 375 |
+
"total = 0\n",
|
| 376 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 377 |
+
"model.eval()\n",
|
| 378 |
+
"with torch.no_grad():\n",
|
| 379 |
+
" for images, labels in tqdm(testloader):\n",
|
| 380 |
+
" inputs, labels = images.to(device), labels\n",
|
| 381 |
+
" outputs = model(inputs, return_mask = True)\n",
|
| 382 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 383 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
| 384 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 385 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 386 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 387 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 388 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 389 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 390 |
+
" total += labels[0].size(0)\n",
|
| 391 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 392 |
+
" \n",
|
| 393 |
+
" \n",
|
| 394 |
+
"# Calculate training accuracy after each epoch\n",
|
| 395 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 396 |
+
"print(\"===========================\")\n",
|
| 397 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 398 |
+
"print(\"===========================\")\n",
|
| 399 |
+
"\n",
|
| 400 |
+
"import pickle\n",
|
| 401 |
+
"\n",
|
| 402 |
+
"# Pickle the dictionary to a file\n",
|
| 403 |
+
"with open('models_mask/test_7109.pkl', 'wb') as f:\n",
|
| 404 |
+
" pickle.dump(results, f)\n",
|
| 405 |
+
"\n",
|
| 406 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 407 |
+
"\n",
|
| 408 |
+
"# Example binary labels\n",
|
| 409 |
+
"true = results['true'] # ground truth\n",
|
| 410 |
+
"pred = results['pred'] # predicted\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"# Compute metrics\n",
|
| 413 |
+
"precision = precision_score(true, pred)\n",
|
| 414 |
+
"recall = recall_score(true, pred)\n",
|
| 415 |
+
"f1 = f1_score(true, pred)\n",
|
| 416 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 417 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"# Compute FPR\n",
|
| 420 |
+
"fpr = fp / (fp + tn)\n",
|
| 421 |
+
"\n",
|
| 422 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 423 |
+
"\n",
|
| 424 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 425 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 426 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 427 |
+
]
|
| 428 |
+
},
|
| 429 |
+
{
|
| 430 |
+
"cell_type": "code",
|
| 431 |
+
"execution_count": 17,
|
| 432 |
+
"id": "974e62d6-5088-4cd8-9721-6702717eadee",
|
| 433 |
+
"metadata": {},
|
| 434 |
+
"outputs": [
|
| 435 |
+
{
|
| 436 |
+
"name": "stdout",
|
| 437 |
+
"output_type": "stream",
|
| 438 |
+
"text": [
|
| 439 |
+
"98.97499999999998 0.1500555452713003\n",
|
| 440 |
+
"0.9913333333333333 0.0012472191289246482\n",
|
| 441 |
+
"0.9896666666666666 0.0012472191289246482\n"
|
| 442 |
+
]
|
| 443 |
+
}
|
| 444 |
+
],
|
| 445 |
+
"source": [
|
| 446 |
+
"# acc\n",
|
| 447 |
+
"print(np.mean([99.125,98.77, 99.03 ]), np.std([99.125,98.77, 99.03 ]))\n",
|
| 448 |
+
"# precision\n",
|
| 449 |
+
"print(np.mean([0.991,0.990, 0.993]), np.std([0.991,0.990, 0.993]))\n",
|
| 450 |
+
"# f1\n",
|
| 451 |
+
"print(np.mean([0.990,0.988,0.991 ]),np.std([0.990,0.988,0.991 ]))\n",
|
| 452 |
+
"# recall"
|
| 453 |
+
]
|
| 454 |
+
},
|
| 455 |
+
{
|
| 456 |
+
"cell_type": "code",
|
| 457 |
+
"execution_count": 18,
|
| 458 |
+
"id": "3eee97ad-114f-4090-b54a-6ec0cc7150f5",
|
| 459 |
+
"metadata": {},
|
| 460 |
+
"outputs": [
|
| 461 |
+
{
|
| 462 |
+
"name": "stdout",
|
| 463 |
+
"output_type": "stream",
|
| 464 |
+
"text": [
|
| 465 |
+
"False Positive Rate: 0.200\n"
|
| 466 |
+
]
|
| 467 |
+
}
|
| 468 |
+
],
|
| 469 |
+
"source": [
|
| 470 |
+
"from sklearn.metrics import confusion_matrix\n",
|
| 471 |
+
"\n",
|
| 472 |
+
"# Ground truth and predictions\n",
|
| 473 |
+
"true = [1, 0, 1, 1, 0, 1, 0, 0, 1, 0]\n",
|
| 474 |
+
"pred = [1, 0, 1, 0, 0, 1, 1, 0, 1, 0]\n",
|
| 475 |
+
"\n"
|
| 476 |
+
]
|
| 477 |
+
}
|
| 478 |
+
],
|
| 479 |
+
"metadata": {
|
| 480 |
+
"kernelspec": {
|
| 481 |
+
"display_name": "Python 3 (ipykernel)",
|
| 482 |
+
"language": "python",
|
| 483 |
+
"name": "python3"
|
| 484 |
+
},
|
| 485 |
+
"language_info": {
|
| 486 |
+
"codemirror_mode": {
|
| 487 |
+
"name": "ipython",
|
| 488 |
+
"version": 3
|
| 489 |
+
},
|
| 490 |
+
"file_extension": ".py",
|
| 491 |
+
"mimetype": "text/x-python",
|
| 492 |
+
"name": "python",
|
| 493 |
+
"nbconvert_exporter": "python",
|
| 494 |
+
"pygments_lexer": "ipython3",
|
| 495 |
+
"version": "3.11.9"
|
| 496 |
+
}
|
| 497 |
+
},
|
| 498 |
+
"nbformat": 4,
|
| 499 |
+
"nbformat_minor": 5
|
| 500 |
+
}
|
models/.ipynb_checkpoints/eval_mask_threshold-extend-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"2\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stderr",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"name": "stdout",
|
| 28 |
+
"output_type": "stream",
|
| 29 |
+
"text": [
|
| 30 |
+
"===========================\n",
|
| 31 |
+
"accuracy: 99.195\n",
|
| 32 |
+
"===========================\n",
|
| 33 |
+
"False Positive Rate: 0.005\n",
|
| 34 |
+
"Precision: 0.995\n",
|
| 35 |
+
"Recall: 0.989\n",
|
| 36 |
+
"F1 Score: 0.992\n"
|
| 37 |
+
]
|
| 38 |
+
}
|
| 39 |
+
],
|
| 40 |
+
"source": [
|
| 41 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 42 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 43 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 44 |
+
"from tqdm import tqdm\n",
|
| 45 |
+
"import torch\n",
|
| 46 |
+
"import numpy as np\n",
|
| 47 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 48 |
+
"import torch\n",
|
| 49 |
+
"import torch.nn as nn\n",
|
| 50 |
+
"import torch.optim as optim\n",
|
| 51 |
+
"from tqdm import tqdm \n",
|
| 52 |
+
"import torch.nn.functional as F\n",
|
| 53 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 54 |
+
"import pickle\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"torch.manual_seed(1)\n",
|
| 57 |
+
"# torch.manual_seed(42)\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 61 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 62 |
+
"print(num_gpus)\n",
|
| 63 |
+
"threshold = 0.992\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 66 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"num_classes = 2\n",
|
| 69 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 72 |
+
"model = nn.DataParallel(model)\n",
|
| 73 |
+
"model = model.to(device)\n",
|
| 74 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 75 |
+
"print(\"num params \",params)\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"model_1 = 'models_mask/model-43-99.235_42.pt'\n",
|
| 78 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 79 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 80 |
+
"model = model.eval()\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"# eval\n",
|
| 83 |
+
"val_loss = 0.0\n",
|
| 84 |
+
"correct_valid = 0\n",
|
| 85 |
+
"total = 0\n",
|
| 86 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 87 |
+
"model.eval()\n",
|
| 88 |
+
"with torch.no_grad():\n",
|
| 89 |
+
" for images, labels in testloader:\n",
|
| 90 |
+
" inputs, labels = images.to(device), labels\n",
|
| 91 |
+
" outputs = nn.Softmax(dim = 1)(model(inputs))\n",
|
| 92 |
+
" selection = outputs[:, 1] > threshold\n",
|
| 93 |
+
" predicted = selection.int()\n",
|
| 94 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 95 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 96 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 97 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 98 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 99 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 100 |
+
" total += labels[0].size(0)\n",
|
| 101 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 102 |
+
" \n",
|
| 103 |
+
"# Calculate training accuracy after each epoch\n",
|
| 104 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 105 |
+
"print(\"===========================\")\n",
|
| 106 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 107 |
+
"print(\"===========================\")\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"import pickle\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"# Pickle the dictionary to a file\n",
|
| 112 |
+
"with open('models_mask/test_42.pkl', 'wb') as f:\n",
|
| 113 |
+
" pickle.dump(results, f)\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 116 |
+
"from sklearn.metrics import confusion_matrix\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"# Example binary labels\n",
|
| 119 |
+
"true = results['true'] # ground truth\n",
|
| 120 |
+
"pred = results['pred'] # predicted\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"# Compute metrics\n",
|
| 123 |
+
"precision = precision_score(true, pred)\n",
|
| 124 |
+
"recall = recall_score(true, pred)\n",
|
| 125 |
+
"f1 = f1_score(true, pred)\n",
|
| 126 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 127 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"# Compute FPR\n",
|
| 130 |
+
"fpr = fp / (fp + tn)\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 135 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 136 |
+
"print(f\"F1 Score: {f1:.3f}\")\n"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": 2,
|
| 142 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [
|
| 145 |
+
{
|
| 146 |
+
"name": "stdout",
|
| 147 |
+
"output_type": "stream",
|
| 148 |
+
"text": [
|
| 149 |
+
"2\n",
|
| 150 |
+
"num params encoder 50840\n",
|
| 151 |
+
"num params 21496282\n"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"name": "stderr",
|
| 156 |
+
"output_type": "stream",
|
| 157 |
+
"text": [
|
| 158 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 159 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"name": "stdout",
|
| 164 |
+
"output_type": "stream",
|
| 165 |
+
"text": [
|
| 166 |
+
"===========================\n",
|
| 167 |
+
"accuracy: 99.195\n",
|
| 168 |
+
"===========================\n",
|
| 169 |
+
"False Positive Rate: 0.007\n",
|
| 170 |
+
"Precision: 0.993\n",
|
| 171 |
+
"Recall: 0.991\n",
|
| 172 |
+
"F1 Score: 0.992\n"
|
| 173 |
+
]
|
| 174 |
+
}
|
| 175 |
+
],
|
| 176 |
+
"source": [
|
| 177 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 178 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 179 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 180 |
+
"from tqdm import tqdm\n",
|
| 181 |
+
"import torch\n",
|
| 182 |
+
"import numpy as np\n",
|
| 183 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 184 |
+
"import torch\n",
|
| 185 |
+
"import torch.nn as nn\n",
|
| 186 |
+
"import torch.optim as optim\n",
|
| 187 |
+
"from tqdm import tqdm \n",
|
| 188 |
+
"import torch.nn.functional as F\n",
|
| 189 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 190 |
+
"import pickle\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"torch.manual_seed(1)\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 196 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 197 |
+
"print(num_gpus)\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 200 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"num_classes = 2\n",
|
| 203 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 206 |
+
"model = nn.DataParallel(model)\n",
|
| 207 |
+
"model = model.to(device)\n",
|
| 208 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 209 |
+
"print(\"num params \",params)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
|
| 213 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 214 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 215 |
+
"model = model.eval()\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"# eval\n",
|
| 218 |
+
"val_loss = 0.0\n",
|
| 219 |
+
"correct_valid = 0\n",
|
| 220 |
+
"total = 0\n",
|
| 221 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 222 |
+
"model.eval()\n",
|
| 223 |
+
"with torch.no_grad():\n",
|
| 224 |
+
" for images, labels in testloader:\n",
|
| 225 |
+
" inputs, labels = images.to(device), labels\n",
|
| 226 |
+
" outputs = nn.Softmax(dim = 1)(model(inputs))\n",
|
| 227 |
+
" selection = outputs[:, 1] > threshold\n",
|
| 228 |
+
" predicted = selection.int()\n",
|
| 229 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 230 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 231 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 232 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 233 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 234 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 235 |
+
" total += labels[0].size(0)\n",
|
| 236 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 237 |
+
" \n",
|
| 238 |
+
" \n",
|
| 239 |
+
"# Calculate training accuracy after each epoch\n",
|
| 240 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 241 |
+
"print(\"===========================\")\n",
|
| 242 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 243 |
+
"print(\"===========================\")\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"import pickle\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"# Pickle the dictionary to a file\n",
|
| 248 |
+
"with open('models_mask/test_1.pkl', 'wb') as f:\n",
|
| 249 |
+
" pickle.dump(results, f)\n",
|
| 250 |
+
"\n",
|
| 251 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"# Example binary labels\n",
|
| 254 |
+
"true = results['true'] # ground truth\n",
|
| 255 |
+
"pred = results['pred'] # predicted\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"# Compute metrics\n",
|
| 258 |
+
"precision = precision_score(true, pred)\n",
|
| 259 |
+
"recall = recall_score(true, pred)\n",
|
| 260 |
+
"f1 = f1_score(true, pred)\n",
|
| 261 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 262 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"# Compute FPR\n",
|
| 265 |
+
"fpr = fp / (fp + tn)\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 270 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 271 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"cell_type": "code",
|
| 276 |
+
"execution_count": 3,
|
| 277 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
| 278 |
+
"metadata": {},
|
| 279 |
+
"outputs": [
|
| 280 |
+
{
|
| 281 |
+
"name": "stdout",
|
| 282 |
+
"output_type": "stream",
|
| 283 |
+
"text": [
|
| 284 |
+
"2\n",
|
| 285 |
+
"num params encoder 50840\n",
|
| 286 |
+
"num params 21496282\n"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"name": "stderr",
|
| 291 |
+
"output_type": "stream",
|
| 292 |
+
"text": [
|
| 293 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 294 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"name": "stdout",
|
| 299 |
+
"output_type": "stream",
|
| 300 |
+
"text": [
|
| 301 |
+
"===========================\n",
|
| 302 |
+
"accuracy: 99.035\n",
|
| 303 |
+
"===========================\n",
|
| 304 |
+
"False Positive Rate: 0.007\n",
|
| 305 |
+
"Precision: 0.993\n",
|
| 306 |
+
"Recall: 0.987\n",
|
| 307 |
+
"F1 Score: 0.990\n"
|
| 308 |
+
]
|
| 309 |
+
}
|
| 310 |
+
],
|
| 311 |
+
"source": [
|
| 312 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 313 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 314 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
| 315 |
+
"from tqdm import tqdm\n",
|
| 316 |
+
"import torch\n",
|
| 317 |
+
"import numpy as np\n",
|
| 318 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 319 |
+
"import torch\n",
|
| 320 |
+
"import torch.nn as nn\n",
|
| 321 |
+
"import torch.optim as optim\n",
|
| 322 |
+
"from tqdm import tqdm \n",
|
| 323 |
+
"import torch.nn.functional as F\n",
|
| 324 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 325 |
+
"import pickle\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"torch.manual_seed(1)\n",
|
| 328 |
+
"# torch.manual_seed(42)\n",
|
| 329 |
+
"\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 332 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 333 |
+
"print(num_gpus)\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
| 336 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"num_classes = 2\n",
|
| 339 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 342 |
+
"model = nn.DataParallel(model)\n",
|
| 343 |
+
"model = model.to(device)\n",
|
| 344 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 345 |
+
"print(\"num params \",params)\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"\n",
|
| 348 |
+
"model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
|
| 349 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
| 350 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
| 351 |
+
"model = model.eval()\n",
|
| 352 |
+
"\n",
|
| 353 |
+
"# eval\n",
|
| 354 |
+
"val_loss = 0.0\n",
|
| 355 |
+
"correct_valid = 0\n",
|
| 356 |
+
"total = 0\n",
|
| 357 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
| 358 |
+
"model.eval()\n",
|
| 359 |
+
"with torch.no_grad():\n",
|
| 360 |
+
" for images, labels in testloader:\n",
|
| 361 |
+
" inputs, labels = images.to(device), labels\n",
|
| 362 |
+
" outputs = nn.Softmax(dim = 1)(model(inputs))\n",
|
| 363 |
+
" selection = outputs[:, 1] > threshold\n",
|
| 364 |
+
" predicted = selection.int()\n",
|
| 365 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
| 366 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
| 367 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
| 368 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
| 369 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
| 370 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
| 371 |
+
" total += labels[0].size(0)\n",
|
| 372 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
| 373 |
+
" \n",
|
| 374 |
+
" \n",
|
| 375 |
+
"# Calculate training accuracy after each epoch\n",
|
| 376 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
| 377 |
+
"print(\"===========================\")\n",
|
| 378 |
+
"print('accuracy: ', val_accuracy)\n",
|
| 379 |
+
"print(\"===========================\")\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"import pickle\n",
|
| 382 |
+
"\n",
|
| 383 |
+
"# Pickle the dictionary to a file\n",
|
| 384 |
+
"with open('models_mask/test_7109.pkl', 'wb') as f:\n",
|
| 385 |
+
" pickle.dump(results, f)\n",
|
| 386 |
+
"\n",
|
| 387 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
| 388 |
+
"\n",
|
| 389 |
+
"# Example binary labels\n",
|
| 390 |
+
"true = results['true'] # ground truth\n",
|
| 391 |
+
"pred = results['pred'] # predicted\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"# Compute metrics\n",
|
| 394 |
+
"precision = precision_score(true, pred)\n",
|
| 395 |
+
"recall = recall_score(true, pred)\n",
|
| 396 |
+
"f1 = f1_score(true, pred)\n",
|
| 397 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
| 398 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
| 399 |
+
"\n",
|
| 400 |
+
"# Compute FPR\n",
|
| 401 |
+
"fpr = fp / (fp + tn)\n",
|
| 402 |
+
"\n",
|
| 403 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
| 406 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
| 407 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
| 408 |
+
]
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"cell_type": "code",
|
| 412 |
+
"execution_count": 6,
|
| 413 |
+
"id": "974e62d6-5088-4cd8-9721-6702717eadee",
|
| 414 |
+
"metadata": {},
|
| 415 |
+
"outputs": [
|
| 416 |
+
{
|
| 417 |
+
"name": "stdout",
|
| 418 |
+
"output_type": "stream",
|
| 419 |
+
"text": [
|
| 420 |
+
"99.14166666666665 0.07542472332656346\n",
|
| 421 |
+
"0.9913333333333333 0.0012472191289246482\n",
|
| 422 |
+
"0.9913333333333334 0.0012472191289246482\n",
|
| 423 |
+
"0.006333333333333333 0.0009428090415820634\n"
|
| 424 |
+
]
|
| 425 |
+
}
|
| 426 |
+
],
|
| 427 |
+
"source": [
|
| 428 |
+
"# acc\n",
|
| 429 |
+
"print(np.mean([99.195,99.195, 99.035 ]), np.std([99.195,99.195, 99.035]))\n",
|
| 430 |
+
"# recall\n",
|
| 431 |
+
"print(np.mean([0.991,0.991, 0.987]), np.std([0.991,0.990, 0.993]))\n",
|
| 432 |
+
"# f1\n",
|
| 433 |
+
"print(np.mean([0.992,0.992,0.990 ]),np.std([0.990,0.988,0.991 ]))\n",
|
| 434 |
+
"# fp\n",
|
| 435 |
+
"print(np.mean([0.005,0.007,0.007 ]),np.std([0.005,0.007,0.007]))\n"
|
| 436 |
+
]
|
| 437 |
+
}
|
| 438 |
+
],
|
| 439 |
+
"metadata": {
|
| 440 |
+
"kernelspec": {
|
| 441 |
+
"display_name": "Python 3 (ipykernel)",
|
| 442 |
+
"language": "python",
|
| 443 |
+
"name": "python3"
|
| 444 |
+
},
|
| 445 |
+
"language_info": {
|
| 446 |
+
"codemirror_mode": {
|
| 447 |
+
"name": "ipython",
|
| 448 |
+
"version": 3
|
| 449 |
+
},
|
| 450 |
+
"file_extension": ".py",
|
| 451 |
+
"mimetype": "text/x-python",
|
| 452 |
+
"name": "python",
|
| 453 |
+
"nbconvert_exporter": "python",
|
| 454 |
+
"pygments_lexer": "ipython3",
|
| 455 |
+
"version": "3.11.9"
|
| 456 |
+
}
|
| 457 |
+
},
|
| 458 |
+
"nbformat": 4,
|
| 459 |
+
"nbformat_minor": 5
|
| 460 |
+
}
|
models/.ipynb_checkpoints/plot_reatime_hits-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/practice_cnn_train-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "851f001a-3882-42cf-8e45-1bb7c4193d20",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"6\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
| 21 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 22 |
+
"import torch\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 25 |
+
"import torch\n",
|
| 26 |
+
"import torch.nn as nn\n",
|
| 27 |
+
"import torch.optim as optim\n",
|
| 28 |
+
"from tqdm import tqdm \n",
|
| 29 |
+
"import torch.nn.functional as F\n",
|
| 30 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 31 |
+
"import pickle\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"torch.manual_seed(1)\n",
|
| 34 |
+
"# torch.manual_seed(42)\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 38 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 39 |
+
"print(num_gpus)\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"# Create custom dataset instance\n",
|
| 42 |
+
"# Create custom dataset instance\n",
|
| 43 |
+
"data_dir = '/mnt/buf0/pma/frbnn/train_ready'\n",
|
| 44 |
+
"dataset = CustomDataset(data_dir, transform=transform)\n",
|
| 45 |
+
"valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'\n",
|
| 46 |
+
"valid_dataset = CustomDataset(valid_data_dir, transform=transform)\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"num_classes = 2\n",
|
| 50 |
+
"trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 53 |
+
"model = nn.DataParallel(model)\n",
|
| 54 |
+
"model = model.to(device)\n",
|
| 55 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 56 |
+
"print(\"num params \",params)\n"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": 2,
|
| 62 |
+
"id": "676a6ffa-5bed-403d-ba03-627f14b36de2",
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [
|
| 65 |
+
{
|
| 66 |
+
"name": "stderr",
|
| 67 |
+
"output_type": "stream",
|
| 68 |
+
"text": [
|
| 69 |
+
" 0%| | 0/477 [00:00<?, ?batch/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 70 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
| 71 |
+
"100%|██████████████████████████████████████| 477/477 [08:57<00:00, 1.13s/batch]\n"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"ename": "NameError",
|
| 76 |
+
"evalue": "name 'validloader' is not defined",
|
| 77 |
+
"output_type": "error",
|
| 78 |
+
"traceback": [
|
| 79 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 80 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
| 81 |
+
"Cell \u001b[0;32mIn[2], line 29\u001b[0m\n\u001b[1;32m 27\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m validloader:\n\u001b[1;32m 30\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\u001b[38;5;241m.\u001b[39mto(device)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 31\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
|
| 82 |
+
"\u001b[0;31mNameError\u001b[0m: name 'validloader' is not defined"
|
| 83 |
+
]
|
| 84 |
+
}
|
| 85 |
+
],
|
| 86 |
+
"source": [
|
| 87 |
+
"criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))\n",
|
| 88 |
+
"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
|
| 89 |
+
"scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"for epoch in range(5):\n",
|
| 92 |
+
" running_loss = 0.0\n",
|
| 93 |
+
" correct_train = 0\n",
|
| 94 |
+
" total_train = 0\n",
|
| 95 |
+
" with tqdm(trainloader, unit=\"batch\") as tepoch:\n",
|
| 96 |
+
" model.train()\n",
|
| 97 |
+
" for i, (images, labels) in enumerate(tepoch):\n",
|
| 98 |
+
" inputs, labels = images.to(device), labels.to(device).float()\n",
|
| 99 |
+
" optimizer.zero_grad()\n",
|
| 100 |
+
" outputs = model(inputs, return_mask=False).to(device)\n",
|
| 101 |
+
" new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)\n",
|
| 102 |
+
" loss = criterion(outputs, new_label)\n",
|
| 103 |
+
" loss.backward()\n",
|
| 104 |
+
" optimizer.step()\n",
|
| 105 |
+
" running_loss += loss.item()\n",
|
| 106 |
+
" # Calculate training accuracy\n",
|
| 107 |
+
" _, predicted = torch.max(outputs.data, 1)\n",
|
| 108 |
+
" total_train += labels.size(0)\n",
|
| 109 |
+
" correct_train += (predicted == labels).sum().item() \n",
|
| 110 |
+
" val_loss = 0.0\n",
|
| 111 |
+
" correct_valid = 0\n",
|
| 112 |
+
" total = 0\n",
|
| 113 |
+
" model.eval()\n",
|
| 114 |
+
" with torch.no_grad():\n",
|
| 115 |
+
" for images, labels in validloader:\n",
|
| 116 |
+
" inputs, labels = images.to(device), labels.to(device).float()\n",
|
| 117 |
+
" optimizer.zero_grad()\n",
|
| 118 |
+
" outputs = model(inputs, return_mask=False)\n",
|
| 119 |
+
" new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)\n",
|
| 120 |
+
" loss = criterion(outputs, new_label)\n",
|
| 121 |
+
" val_loss += loss.item()\n",
|
| 122 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
| 123 |
+
" total += labels.size(0)\n",
|
| 124 |
+
" correct_valid += (predicted == labels).sum().item()\n",
|
| 125 |
+
" scheduler.step(val_loss)\n",
|
| 126 |
+
" # Calculate training accuracy after each epoch\n",
|
| 127 |
+
" train_accuracy = 100 * correct_train / total_train\n",
|
| 128 |
+
" val_accuracy = correct_valid / total * 100.0\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"\n",
|
| 131 |
+
" print(\"===========================\")\n",
|
| 132 |
+
" print('accuracy: ', epoch, train_accuracy, val_accuracy)\n",
|
| 133 |
+
" print('learning rate: ', scheduler.get_last_lr())\n",
|
| 134 |
+
" print(\"===========================\")"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": null,
|
| 140 |
+
"id": "3faa4a11-89fb-4556-ae87-3645a47fa00d",
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"outputs": [],
|
| 143 |
+
"source": [
|
| 144 |
+
"train_accuracy = 100 * correct_train / total_train\n",
|
| 145 |
+
"print('accuracy: ', epoch, train_accuracy)"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "code",
|
| 150 |
+
"execution_count": null,
|
| 151 |
+
"id": "e586c4d2-a7f4-4f14-81fc-4f84ffac52b3",
|
| 152 |
+
"metadata": {},
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"source": [
|
| 155 |
+
"import sigpyproc.readers as r\n",
|
| 156 |
+
"import cv2\n",
|
| 157 |
+
"import numpy as np\n",
|
| 158 |
+
"import matplotlib.pyplot as plt\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"from scipy.special import softmax\n",
|
| 161 |
+
"%matplotlib inline\n",
|
| 162 |
+
"path = '/mnt/primary/ata/projects/p051/fil_60565_59210_9756774_J0534+2200_0001/LoB.C0928/fil_60565_59210_9756774_J0534+2200_0001-beam0000.fil'\n",
|
| 163 |
+
"# path = '/mnt/primary/ata/projects/p051/fil_60564_62428_4679748_J0332+5434_0001/LoB.C0928/fil_60564_62428_4679748_J0332+5434_0001-beam0000.fil'\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"# Get some metadata\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"# Open the filterbank file\n",
|
| 168 |
+
"fil = r.FilReader(path)\n",
|
| 169 |
+
"header = fil.header\n",
|
| 170 |
+
"print(\"Header:\", header)\n",
|
| 171 |
+
"n=100\n",
|
| 172 |
+
"li = [ 7257608, 7324207, 10393163, 10641071, 11130537, 11085081,\n",
|
| 173 |
+
" 11419145, 11964112, 12329364, 13047181]\n",
|
| 174 |
+
"for el in li:\n",
|
| 175 |
+
" data = torch.tensor(fil.read_block(el-1024, 2048)).cuda()\n",
|
| 176 |
+
" print(data.shape)\n",
|
| 177 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
| 178 |
+
" print(softmax(out.detach().cpu().numpy(), axis=1))\n",
|
| 179 |
+
" plt.figure(figsize=(10,10))\n",
|
| 180 |
+
" plt.imshow(data.cpu().numpy(), aspect = 10)\n",
|
| 181 |
+
" plt.show()"
|
| 182 |
+
]
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"cell_type": "code",
|
| 186 |
+
"execution_count": null,
|
| 187 |
+
"id": "609e5564-f14f-4bd1-b604-68e7e7d42834",
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [],
|
| 190 |
+
"source": [
|
| 191 |
+
"triggers = []\n",
|
| 192 |
+
"counter = 0\n",
|
| 193 |
+
"with torch.no_grad():\n",
|
| 194 |
+
" for i in range(2048,10201921, 2048 ):\n",
|
| 195 |
+
" data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
|
| 196 |
+
" # Shuffle the tensor using the random indices\n",
|
| 197 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
| 198 |
+
" triggers.append(softmax(out.detach().cpu().numpy(), axis=1))\n",
|
| 199 |
+
" counter += 1\n",
|
| 200 |
+
" if counter > 1000:\n",
|
| 201 |
+
" break"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": null,
|
| 207 |
+
"id": "08ee6dcf-cb30-4490-8624-4e52552fdf39",
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"outputs": [],
|
| 210 |
+
"source": [
|
| 211 |
+
"print(triggers[0])"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": null,
|
| 217 |
+
"id": "8c56c6f5-5a0b-4854-8a94-066a9baf4cfc",
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"stack = np.stack(triggers)\n",
|
| 222 |
+
"positives = stack[:,0,1]\n",
|
| 223 |
+
"num_pos = np.where(positives > 0.5)[0].shape[0]\n",
|
| 224 |
+
"print(num_pos)"
|
| 225 |
+
]
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"cell_type": "code",
|
| 229 |
+
"execution_count": null,
|
| 230 |
+
"id": "eb1d1591-8855-4989-bf12-c8a9cdbf2a4d",
|
| 231 |
+
"metadata": {},
|
| 232 |
+
"outputs": [],
|
| 233 |
+
"source": [
|
| 234 |
+
"import pickle\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"# Path to your pickle file\n",
|
| 237 |
+
"file_path = \"../dataset_generator/dir.pkl\"\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"# Open and load the pickle file\n",
|
| 240 |
+
"with open(file_path, \"rb\") as file: # Use \"rb\" mode for reading binary files\n",
|
| 241 |
+
" data = pickle.load(file)\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"# Print the contents of the file\n"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"cell_type": "code",
|
| 248 |
+
"execution_count": null,
|
| 249 |
+
"id": "46f61d7e-55fa-44fe-be94-d4ddb3c576f9",
|
| 250 |
+
"metadata": {},
|
| 251 |
+
"outputs": [],
|
| 252 |
+
"source": [
|
| 253 |
+
"import sigpyproc.readers as r\n",
|
| 254 |
+
"import cv2\n",
|
| 255 |
+
"import numpy as np\n",
|
| 256 |
+
"import matplotlib.pyplot as plt\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"from scipy.special import softmax\n",
|
| 259 |
+
"%matplotlib inline\n",
|
| 260 |
+
"path = data[0]\n",
|
| 261 |
+
"model.eval()\n",
|
| 262 |
+
"\n",
|
| 263 |
+
"fil = r.FilReader(path)\n",
|
| 264 |
+
"header = fil.header\n",
|
| 265 |
+
"print(\"Header:\", header)\n",
|
| 266 |
+
"n=100\n",
|
| 267 |
+
"\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"triggers = []\n",
|
| 270 |
+
"counter = 0\n",
|
| 271 |
+
"for i in range(2048,10201921, 2048):\n",
|
| 272 |
+
" data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
|
| 273 |
+
" # Shuffle the tensor using the random indices\n",
|
| 274 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
| 275 |
+
" triggers.append(softmax(out.detach().cpu().numpy(), axis=1))\n",
|
| 276 |
+
" counter += 1\n",
|
| 277 |
+
" if counter > 1000:\n",
|
| 278 |
+
" break"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"cell_type": "code",
|
| 283 |
+
"execution_count": null,
|
| 284 |
+
"id": "413d402e-2ce3-49fc-bbd4-a3cf1cc92388",
|
| 285 |
+
"metadata": {},
|
| 286 |
+
"outputs": [],
|
| 287 |
+
"source": [
|
| 288 |
+
"print(triggers[0])"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "code",
|
| 293 |
+
"execution_count": null,
|
| 294 |
+
"id": "5c039dee-1b9b-4664-b42a-a79d780f37f1",
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [],
|
| 297 |
+
"source": [
|
| 298 |
+
"stack = np.stack(triggers)\n",
|
| 299 |
+
"positives = stack[:,0,1]\n",
|
| 300 |
+
"num_pos = np.where(positives > 0.5)[0].shape[0]\n",
|
| 301 |
+
"print(num_pos)"
|
| 302 |
+
]
|
| 303 |
+
}
|
| 304 |
+
],
|
| 305 |
+
"metadata": {
|
| 306 |
+
"kernelspec": {
|
| 307 |
+
"display_name": "Python 3 (ipykernel)",
|
| 308 |
+
"language": "python",
|
| 309 |
+
"name": "python3"
|
| 310 |
+
},
|
| 311 |
+
"language_info": {
|
| 312 |
+
"codemirror_mode": {
|
| 313 |
+
"name": "ipython",
|
| 314 |
+
"version": 3
|
| 315 |
+
},
|
| 316 |
+
"file_extension": ".py",
|
| 317 |
+
"mimetype": "text/x-python",
|
| 318 |
+
"name": "python",
|
| 319 |
+
"nbconvert_exporter": "python",
|
| 320 |
+
"pygments_lexer": "ipython3",
|
| 321 |
+
"version": "3.11.9"
|
| 322 |
+
}
|
| 323 |
+
},
|
| 324 |
+
"nbformat": 4,
|
| 325 |
+
"nbformat_minor": 5
|
| 326 |
+
}
|
models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b063cb5ad71d38a551a5a4beb9ae21399e5e633af128c2608645398509854239
|
| 3 |
+
size 16982029
|
models/.ipynb_checkpoints/recover_new_crab-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/recover_new_crab-debug-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 8,
|
| 6 |
+
"id": "851f001a-3882-42cf-8e45-1bb7c4193d20",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"6\n",
|
| 14 |
+
"num params encoder 50840\n",
|
| 15 |
+
"num params 21496282\n"
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
],
|
| 19 |
+
"source": [
|
| 20 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
| 21 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 22 |
+
"import torch\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
| 25 |
+
"import torch\n",
|
| 26 |
+
"import torch.nn as nn\n",
|
| 27 |
+
"import torch.optim as optim\n",
|
| 28 |
+
"from tqdm import tqdm \n",
|
| 29 |
+
"import torch.nn.functional as F\n",
|
| 30 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
| 31 |
+
"import pickle\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"torch.manual_seed(1)\n",
|
| 34 |
+
"# torch.manual_seed(42)\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 38 |
+
"num_gpus = torch.cuda.device_count()\n",
|
| 39 |
+
"print(num_gpus)\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"# Create custom dataset instance\n",
|
| 42 |
+
"data_dir = '/mnt/buf0/pma/frbnn/train_ready'\n",
|
| 43 |
+
"dataset = CustomDataset(data_dir, transform=transform)\n",
|
| 44 |
+
"valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'\n",
|
| 45 |
+
"valid_dataset = CustomDataset(valid_data_dir, transform=transform)\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"num_classes = 2\n",
|
| 49 |
+
"trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
| 52 |
+
"model = nn.DataParallel(model)\n",
|
| 53 |
+
"model = model.to(device)\n",
|
| 54 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
| 55 |
+
"print(\"num params \",params)\n"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": 9,
|
| 61 |
+
"id": "676a6ffa-5bed-403d-ba03-627f14b36de2",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"# model_path = 'models/model-62-98.78.pt'\n",
|
| 66 |
+
"# model_path = 'models/model-47-99.125.pt'\n",
|
| 67 |
+
"model_path = 'models_mask/model-37-99.175_42.pt'\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# model_path = 'models_mask/model-10-97.055_1.pt'\n",
|
| 70 |
+
"model.load_state_dict(torch.load(model_path, weights_only=True))\n",
|
| 71 |
+
"model = model.eval()"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": 10,
|
| 77 |
+
"id": "89d108de-7eae-4bbd-837c-8e657082a1e6",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [
|
| 80 |
+
{
|
| 81 |
+
"name": "stdout",
|
| 82 |
+
"output_type": "stream",
|
| 83 |
+
"text": [
|
| 84 |
+
"Header(filename='/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil', data_type='raw data', nchans=192, foff=-0.5, fch1=1187.5, nbits=32, tsamp=6.4e-05, tstart=60692.03208333333, nsamples=28125184, nifs=1, coord=<SkyCoord (ICRS): (ra, dec) in deg\n",
|
| 85 |
+
" (83.63322, 22.01446)>, azimuth=<Angle 80.54659271 deg>, zenith=<Angle 66.41192055 deg>, telescope='Effelsberg LOFAR', backend='FAKE', source='crab', frame='topocentric', ibeam=0, nbeams=2, dm=0, period=0, accel=0, signed=False, rawdatafile='', stream_info=StreamInfo(entries=[FileInfo(filename='/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil', hdrlen=338, datalen=21600141312, nsamples=28125184, tstart=60692.03208333333, tsamp=6.4e-05)]))\n"
|
| 86 |
+
]
|
| 87 |
+
}
|
| 88 |
+
],
|
| 89 |
+
"source": [
|
| 90 |
+
"import sigpyproc.readers as r\n",
|
| 91 |
+
"import cv2\n",
|
| 92 |
+
"import numpy as np\n",
|
| 93 |
+
"import matplotlib.pyplot as plt\n",
|
| 94 |
+
"fil = r.FilReader('/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil')\n",
|
| 95 |
+
"# fil = r.FilReader('/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoB.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil')\n",
|
| 96 |
+
"header = fil.header\n",
|
| 97 |
+
"print(header)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": 11,
|
| 103 |
+
"id": "0b276e6e-d6c8-41da-808d-542ee22133d1",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [
|
| 106 |
+
{
|
| 107 |
+
"name": "stderr",
|
| 108 |
+
"output_type": "stream",
|
| 109 |
+
"text": [
|
| 110 |
+
" 0%| | 0/13732 [00:00<?, ?it/s]/tmp/ipykernel_19961/1777549771.py:15: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
| 111 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
| 112 |
+
" 3%|▉ | 351/13732 [00:13<08:27, 26.38it/s]\n"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"ename": "KeyboardInterrupt",
|
| 117 |
+
"evalue": "",
|
| 118 |
+
"output_type": "error",
|
| 119 |
+
"traceback": [
|
| 120 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 121 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 122 |
+
"Cell \u001b[0;32mIn[11], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m counter \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m2048\u001b[39m,header\u001b[38;5;241m.\u001b[39mnsamples, \u001b[38;5;241m2048\u001b[39m)):\n\u001b[0;32m---> 13\u001b[0m data \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(fil\u001b[38;5;241m.\u001b[39mread_block(i\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1024\u001b[39m, \u001b[38;5;241m2048\u001b[39m))\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# Shuffle the tensor using the random indices\u001b[39;00m\n\u001b[1;32m 15\u001b[0m out \u001b[38;5;241m=\u001b[39m model(transform(torch\u001b[38;5;241m.\u001b[39mtensor(data)\u001b[38;5;241m.\u001b[39mcuda())[\u001b[38;5;28;01mNone\u001b[39;00m])\n",
|
| 123 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"name": "stdout",
|
| 128 |
+
"output_type": "stream",
|
| 129 |
+
"text": [
|
| 130 |
+
"Error in callback <function flush_figures at 0x7f6c8689ae80> (for post_execute), with arguments args (),kwargs {}:\n"
|
| 131 |
+
]
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"ename": "KeyboardInterrupt",
|
| 135 |
+
"evalue": "",
|
| 136 |
+
"output_type": "error",
|
| 137 |
+
"traceback": [
|
| 138 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 139 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 140 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib_inline/backend_inline.py:126\u001b[0m, in \u001b[0;36mflush_figures\u001b[0;34m()\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m InlineBackend\u001b[38;5;241m.\u001b[39minstance()\u001b[38;5;241m.\u001b[39mclose_figures:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;66;03m# ignore the tracking, just draw and close all figures\u001b[39;00m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 126\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m show(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 128\u001b[0m \u001b[38;5;66;03m# safely show traceback if in IPython, else raise\u001b[39;00m\n\u001b[1;32m 129\u001b[0m ip \u001b[38;5;241m=\u001b[39m get_ipython()\n",
|
| 141 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib_inline/backend_inline.py:90\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(close, block)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m figure_manager \u001b[38;5;129;01min\u001b[39;00m Gcf\u001b[38;5;241m.\u001b[39mget_all_fig_managers():\n\u001b[0;32m---> 90\u001b[0m display(\n\u001b[1;32m 91\u001b[0m figure_manager\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mfigure,\n\u001b[1;32m 92\u001b[0m metadata\u001b[38;5;241m=\u001b[39m_fetch_figure_metadata(figure_manager\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mfigure)\n\u001b[1;32m 93\u001b[0m )\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 95\u001b[0m show\u001b[38;5;241m.\u001b[39m_to_draw \u001b[38;5;241m=\u001b[39m []\n",
|
| 142 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/display_functions.py:298\u001b[0m, in \u001b[0;36mdisplay\u001b[0;34m(include, exclude, metadata, transient, display_id, raw, clear, *objs, **kwargs)\u001b[0m\n\u001b[1;32m 296\u001b[0m publish_display_data(data\u001b[38;5;241m=\u001b[39mobj, metadata\u001b[38;5;241m=\u001b[39mmetadata, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 298\u001b[0m format_dict, md_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mformat\u001b[39m(obj, include\u001b[38;5;241m=\u001b[39minclude, exclude\u001b[38;5;241m=\u001b[39mexclude)\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m format_dict:\n\u001b[1;32m 300\u001b[0m \u001b[38;5;66;03m# nothing to display (e.g. _ipython_display_ took over)\u001b[39;00m\n\u001b[1;32m 301\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n",
|
| 143 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:179\u001b[0m, in \u001b[0;36mDisplayFormatter.format\u001b[0;34m(self, obj, include, exclude)\u001b[0m\n\u001b[1;32m 177\u001b[0m md \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 179\u001b[0m data \u001b[38;5;241m=\u001b[39m formatter(obj)\n\u001b[1;32m 180\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 181\u001b[0m \u001b[38;5;66;03m# FIXME: log the exception\u001b[39;00m\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n",
|
| 144 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/decorator.py:232\u001b[0m, in \u001b[0;36mdecorate.<locals>.fun\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kwsyntax:\n\u001b[1;32m 231\u001b[0m args, kw \u001b[38;5;241m=\u001b[39m fix(args, kw, sig)\n\u001b[0;32m--> 232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m caller(func, \u001b[38;5;241m*\u001b[39m(extras \u001b[38;5;241m+\u001b[39m args), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n",
|
| 145 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:223\u001b[0m, in \u001b[0;36mcatch_format_error\u001b[0;34m(method, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"show traceback on failed format call\"\"\"\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 223\u001b[0m r \u001b[38;5;241m=\u001b[39m method(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;66;03m# don't warn on NotImplementedErrors\u001b[39;00m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_return(\u001b[38;5;28;01mNone\u001b[39;00m, args[\u001b[38;5;241m0\u001b[39m])\n",
|
| 146 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:340\u001b[0m, in \u001b[0;36mBaseFormatter.__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 340\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m printer(obj)\n\u001b[1;32m 341\u001b[0m \u001b[38;5;66;03m# Finally look for special method names\u001b[39;00m\n\u001b[1;32m 342\u001b[0m method \u001b[38;5;241m=\u001b[39m get_real_method(obj, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_method)\n",
|
| 147 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/pylabtools.py:152\u001b[0m, in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, base64, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend_bases\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FigureCanvasBase\n\u001b[1;32m 150\u001b[0m FigureCanvasBase(fig)\n\u001b[0;32m--> 152\u001b[0m fig\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mprint_figure(bytes_io, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n\u001b[1;32m 153\u001b[0m data \u001b[38;5;241m=\u001b[39m bytes_io\u001b[38;5;241m.\u001b[39mgetvalue()\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fmt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msvg\u001b[39m\u001b[38;5;124m'\u001b[39m:\n",
|
| 148 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/backend_bases.py:2164\u001b[0m, in \u001b[0;36mFigureCanvasBase.print_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m 2161\u001b[0m \u001b[38;5;66;03m# we do this instead of `self.figure.draw_without_rendering`\u001b[39;00m\n\u001b[1;32m 2162\u001b[0m \u001b[38;5;66;03m# so that we can inject the orientation\u001b[39;00m\n\u001b[1;32m 2163\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(renderer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_draw_disabled\u001b[39m\u001b[38;5;124m\"\u001b[39m, nullcontext)():\n\u001b[0;32m-> 2164\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 2165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches:\n\u001b[1;32m 2166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtight\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
|
| 149 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:95\u001b[0m, in \u001b[0;36m_finalize_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(draw)\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdraw_wrapper\u001b[39m(artist, renderer, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 95\u001b[0m result \u001b[38;5;241m=\u001b[39m draw(artist, renderer, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m renderer\u001b[38;5;241m.\u001b[39m_rasterizing:\n\u001b[1;32m 97\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstop_rasterizing()\n",
|
| 150 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
| 151 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/figure.py:3154\u001b[0m, in \u001b[0;36mFigure.draw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m 3151\u001b[0m \u001b[38;5;66;03m# ValueError can occur when resizing a window.\u001b[39;00m\n\u001b[1;32m 3153\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpatch\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[0;32m-> 3154\u001b[0m mimage\u001b[38;5;241m.\u001b[39m_draw_list_compositing_images(\n\u001b[1;32m 3155\u001b[0m renderer, \u001b[38;5;28mself\u001b[39m, artists, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msuppressComposite)\n\u001b[1;32m 3157\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m sfig \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msubfigs:\n\u001b[1;32m 3158\u001b[0m sfig\u001b[38;5;241m.\u001b[39mdraw(renderer)\n",
|
| 152 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:132\u001b[0m, in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m not_composite \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m has_images:\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m artists:\n\u001b[0;32m--> 132\u001b[0m a\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# Composite any adjacent images together\u001b[39;00m\n\u001b[1;32m 135\u001b[0m image_group \u001b[38;5;241m=\u001b[39m []\n",
|
| 153 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
| 154 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/axes/_base.py:3070\u001b[0m, in \u001b[0;36m_AxesBase.draw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m 3067\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artists_rasterized:\n\u001b[1;32m 3068\u001b[0m _draw_rasterized(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure, artists_rasterized, renderer)\n\u001b[0;32m-> 3070\u001b[0m mimage\u001b[38;5;241m.\u001b[39m_draw_list_compositing_images(\n\u001b[1;32m 3071\u001b[0m renderer, \u001b[38;5;28mself\u001b[39m, artists, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39msuppressComposite)\n\u001b[1;32m 3073\u001b[0m renderer\u001b[38;5;241m.\u001b[39mclose_group(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maxes\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 3074\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstale \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
| 155 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:132\u001b[0m, in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m not_composite \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m has_images:\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m artists:\n\u001b[0;32m--> 132\u001b[0m a\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# Composite any adjacent images together\u001b[39;00m\n\u001b[1;32m 135\u001b[0m image_group \u001b[38;5;241m=\u001b[39m []\n",
|
| 156 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
| 157 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:649\u001b[0m, in \u001b[0;36m_ImageBase.draw\u001b[0;34m(self, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m 647\u001b[0m renderer\u001b[38;5;241m.\u001b[39mdraw_image(gc, l, b, im, trans)\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 649\u001b[0m im, l, b, trans \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmake_image(\n\u001b[1;32m 650\u001b[0m renderer, renderer\u001b[38;5;241m.\u001b[39mget_image_magnification())\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m im \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 652\u001b[0m renderer\u001b[38;5;241m.\u001b[39mdraw_image(gc, l, b, im)\n",
|
| 158 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:939\u001b[0m, in \u001b[0;36mAxesImage.make_image\u001b[0;34m(self, renderer, magnification, unsampled)\u001b[0m\n\u001b[1;32m 936\u001b[0m transformed_bbox \u001b[38;5;241m=\u001b[39m TransformedBbox(bbox, trans)\n\u001b[1;32m 937\u001b[0m clip \u001b[38;5;241m=\u001b[39m ((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_clip_box() \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes\u001b[38;5;241m.\u001b[39mbbox) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_clip_on()\n\u001b[1;32m 938\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39mbbox)\n\u001b[0;32m--> 939\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_image(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_A, bbox, transformed_bbox, clip,\n\u001b[1;32m 940\u001b[0m magnification, unsampled\u001b[38;5;241m=\u001b[39munsampled)\n",
|
| 159 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:526\u001b[0m, in \u001b[0;36m_ImageBase._make_image\u001b[0;34m(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)\u001b[0m\n\u001b[1;32m 521\u001b[0m mask \u001b[38;5;241m=\u001b[39m (np\u001b[38;5;241m.\u001b[39mwhere(A\u001b[38;5;241m.\u001b[39mmask, np\u001b[38;5;241m.\u001b[39mfloat32(np\u001b[38;5;241m.\u001b[39mnan), np\u001b[38;5;241m.\u001b[39mfloat32(\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m A\u001b[38;5;241m.\u001b[39mmask\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m A\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;66;03m# nontrivial mask\u001b[39;00m\n\u001b[1;32m 523\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m np\u001b[38;5;241m.\u001b[39mones_like(A, np\u001b[38;5;241m.\u001b[39mfloat32))\n\u001b[1;32m 524\u001b[0m \u001b[38;5;66;03m# we always have to interpolate the mask to account for\u001b[39;00m\n\u001b[1;32m 525\u001b[0m \u001b[38;5;66;03m# non-affine transformations\u001b[39;00m\n\u001b[0;32m--> 526\u001b[0m out_alpha \u001b[38;5;241m=\u001b[39m _resample(\u001b[38;5;28mself\u001b[39m, mask, out_shape, t, resample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m mask \u001b[38;5;66;03m# Make sure we don't use mask anymore!\u001b[39;00m\n\u001b[1;32m 528\u001b[0m \u001b[38;5;66;03m# Agg updates out_alpha in place. If the pixel has no image\u001b[39;00m\n\u001b[1;32m 529\u001b[0m \u001b[38;5;66;03m# data it will not be updated (and still be 0 as we initialized\u001b[39;00m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;66;03m# it), if input data that would go into that output pixel than\u001b[39;00m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;66;03m# it will be `nan`, if all the input data for a pixel is good\u001b[39;00m\n\u001b[1;32m 532\u001b[0m \u001b[38;5;66;03m# it will be 1, and if there is _some_ good data in that output\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;66;03m# pixel it will be between [0, 1] (such as a rotated image).\u001b[39;00m\n",
|
| 160 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:208\u001b[0m, in \u001b[0;36m_resample\u001b[0;34m(image_obj, data, out_shape, transform, resample, alpha)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m resample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 207\u001b[0m resample \u001b[38;5;241m=\u001b[39m image_obj\u001b[38;5;241m.\u001b[39mget_resample()\n\u001b[0;32m--> 208\u001b[0m _image\u001b[38;5;241m.\u001b[39mresample(data, out, transform,\n\u001b[1;32m 209\u001b[0m _interpd_[interpolation],\n\u001b[1;32m 210\u001b[0m resample,\n\u001b[1;32m 211\u001b[0m alpha,\n\u001b[1;32m 212\u001b[0m image_obj\u001b[38;5;241m.\u001b[39mget_filternorm(),\n\u001b[1;32m 213\u001b[0m image_obj\u001b[38;5;241m.\u001b[39mget_filterrad())\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n",
|
| 161 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 162 |
+
]
|
| 163 |
+
}
|
| 164 |
+
],
|
| 165 |
+
"source": [
|
| 166 |
+
"import sigpyproc.readers as r\n",
|
| 167 |
+
"import cv2\n",
|
| 168 |
+
"import numpy as np\n",
|
| 169 |
+
"import matplotlib.pyplot as plt\n",
|
| 170 |
+
"from scipy.special import softmax\n",
|
| 171 |
+
"from tqdm import tqdm\n",
|
| 172 |
+
"%matplotlib inline\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"header = fil.header\n",
|
| 175 |
+
"triggers = []\n",
|
| 176 |
+
"counter = 0\n",
|
| 177 |
+
"for i in tqdm(range(2048,header.nsamples, 2048)):\n",
|
| 178 |
+
" data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
|
| 179 |
+
" # Shuffle the tensor using the random indices\n",
|
| 180 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
| 181 |
+
" out = softmax(out.detach().cpu().numpy(), axis=1)\n",
|
| 182 |
+
" triggers.append(out)\n",
|
| 183 |
+
" counter += 1\n",
|
| 184 |
+
" # if counter > 1000:\n",
|
| 185 |
+
" # break\n",
|
| 186 |
+
" # if out[0, 1]>0.999:\n",
|
| 187 |
+
" # key = data.cpu().numpy()\n",
|
| 188 |
+
" # plt.figure(figsize=(10,10))\n",
|
| 189 |
+
" # plt.imshow(data.cpu().numpy(), aspect = 10, vmax = 54557.824)\n",
|
| 190 |
+
"stack = np.stack(triggers)\n",
|
| 191 |
+
"positives = stack[:,0,1]\n",
|
| 192 |
+
"num_pos = np.where(positives > 0.999)[0].shape[0]\n",
|
| 193 |
+
"print(num_pos)"
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
{
|
| 197 |
+
"cell_type": "code",
|
| 198 |
+
"execution_count": null,
|
| 199 |
+
"id": "64df934d-f4a2-49f0-857d-2661b1d78b21",
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"outputs": [],
|
| 202 |
+
"source": [
|
| 203 |
+
"np.flipud()"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": null,
|
| 209 |
+
"id": "1eafb2c1-857e-48be-aa8b-18669c0e0f8c",
|
| 210 |
+
"metadata": {},
|
| 211 |
+
"outputs": [],
|
| 212 |
+
"source": [
|
| 213 |
+
"plt.figure(figsize=(10,10))\n",
|
| 214 |
+
"# plt.imshow(key, aspect = 10, vmax = 54557.824)\n",
|
| 215 |
+
"plt.imshow(key, aspect = 10)"
|
| 216 |
+
]
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"cell_type": "code",
|
| 220 |
+
"execution_count": null,
|
| 221 |
+
"id": "ed3783c3-8ed1-46d6-91e4-e906dfa44913",
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"outputs": [],
|
| 224 |
+
"source": [
|
| 225 |
+
"key.shape"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": null,
|
| 231 |
+
"id": "8b56a356-a582-4f5d-a8e2-20f725a48fb3",
|
| 232 |
+
"metadata": {},
|
| 233 |
+
"outputs": [],
|
| 234 |
+
"source": [
|
| 235 |
+
"total_data =[]\n",
|
| 236 |
+
"for i in range(32):\n",
|
| 237 |
+
" total_data.append(key)\n",
|
| 238 |
+
"total_data = torch.tensor(np.array(total_data))\n",
|
| 239 |
+
"total_data.cpu().detach().numpy().tofile(\"crab_in.bin\")\n",
|
| 240 |
+
"print(total_data.shape)\n",
|
| 241 |
+
"outputs_data = []\n",
|
| 242 |
+
"for i in range(32):\n",
|
| 243 |
+
" temp = model(transform(total_data.cuda()[i,:,:])[None])\n",
|
| 244 |
+
" print(temp)\n",
|
| 245 |
+
" # outputs_data.append(softmax(temp.detach().cpu().numpy(), axis=1))\n",
|
| 246 |
+
" outputs_data.append(temp.detach().cpu().numpy())\n",
|
| 247 |
+
"outputs_data = torch.tensor(outputs_data)\n",
|
| 248 |
+
"outputs_data.cpu().detach().numpy().tofile(\"crab_out.bin\")"
|
| 249 |
+
]
|
| 250 |
+
}
|
| 251 |
+
],
|
| 252 |
+
"metadata": {
|
| 253 |
+
"kernelspec": {
|
| 254 |
+
"display_name": "Python 3 (ipykernel)",
|
| 255 |
+
"language": "python",
|
| 256 |
+
"name": "python3"
|
| 257 |
+
},
|
| 258 |
+
"language_info": {
|
| 259 |
+
"codemirror_mode": {
|
| 260 |
+
"name": "ipython",
|
| 261 |
+
"version": 3
|
| 262 |
+
},
|
| 263 |
+
"file_extension": ".py",
|
| 264 |
+
"mimetype": "text/x-python",
|
| 265 |
+
"name": "python",
|
| 266 |
+
"nbconvert_exporter": "python",
|
| 267 |
+
"pygments_lexer": "ipython3",
|
| 268 |
+
"version": "3.11.9"
|
| 269 |
+
}
|
| 270 |
+
},
|
| 271 |
+
"nbformat": 4,
|
| 272 |
+
"nbformat_minor": 5
|
| 273 |
+
}
|
models/.ipynb_checkpoints/recover_new_frb-checkpoint.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.ipynb_checkpoints/resnet_model-checkpoint.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class ResidualBlock(nn.Module):
|
| 6 |
+
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
|
| 7 |
+
super(ResidualBlock, self).__init__()
|
| 8 |
+
self.conv1 = nn.Sequential(
|
| 9 |
+
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
|
| 10 |
+
nn.BatchNorm2d(out_channels),
|
| 11 |
+
nn.ReLU())
|
| 12 |
+
self.conv2 = nn.Sequential(
|
| 13 |
+
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
|
| 14 |
+
nn.BatchNorm2d(out_channels))
|
| 15 |
+
self.downsample = downsample
|
| 16 |
+
self.relu = nn.ReLU()
|
| 17 |
+
self.out_channels = out_channels
|
| 18 |
+
self.dropout_percentage = 0.5
|
| 19 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
| 20 |
+
self.batchnorm_mod = nn.BatchNorm2d(out_channels)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
residual = x
|
| 24 |
+
out = self.conv1(x)
|
| 25 |
+
out = self.dropout1(out)
|
| 26 |
+
# out = self.batchnorm_mod(out)
|
| 27 |
+
out = self.conv2(out)
|
| 28 |
+
out = self.dropout1(out)
|
| 29 |
+
# out = self.batchnorm_mod(out)
|
| 30 |
+
if self.downsample:
|
| 31 |
+
residual = self.downsample(x)
|
| 32 |
+
out += residual
|
| 33 |
+
out = self.relu(out)
|
| 34 |
+
return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ResNet(nn.Module):
|
| 38 |
+
def __init__(self, inchan, block, layers, num_classes = 10):
|
| 39 |
+
super(ResNet, self).__init__()
|
| 40 |
+
self.inplanes = 64
|
| 41 |
+
self.eps = 1e-5
|
| 42 |
+
self.relu = nn.ReLU()
|
| 43 |
+
self.conv1 = nn.Sequential(
|
| 44 |
+
nn.Conv2d(inchan, 64, kernel_size = 7, stride = 2, padding = 3),
|
| 45 |
+
nn.BatchNorm2d(64),
|
| 46 |
+
nn.ReLU())
|
| 47 |
+
self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = 2, padding = 1)
|
| 48 |
+
self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
|
| 49 |
+
self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
|
| 50 |
+
self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
|
| 51 |
+
self.layer3 = self._make_layer(block, 512, layers[3], stride = 1)
|
| 52 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
| 53 |
+
self.fc = nn.Linear(39424, num_classes)
|
| 54 |
+
self.dropout_percentage = 0.3
|
| 55 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
| 56 |
+
|
| 57 |
+
# Encoder
|
| 58 |
+
self.encoder = nn.Sequential(
|
| 59 |
+
nn.Conv2d(24, 32, kernel_size = 3, stride =1, padding = 1),
|
| 60 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
| 61 |
+
nn.Conv2d(32, 64, kernel_size = 3, stride =1, padding = 1),
|
| 62 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
| 63 |
+
nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
|
| 64 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
| 65 |
+
nn.Conv2d(32, 24, kernel_size = 3, stride = 1, padding = 1),
|
| 66 |
+
nn.Sigmoid()
|
| 67 |
+
)
|
| 68 |
+
params = sum(p.numel() for p in self.encoder.parameters())
|
| 69 |
+
print("num params encoder ",params)
|
| 70 |
+
|
| 71 |
+
def norm(self, x):
|
| 72 |
+
shifted = x-x.min()
|
| 73 |
+
maxes = torch.amax(abs(shifted), dim=(-2, -1))
|
| 74 |
+
repeated_maxes = maxes.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[-2],x.shape[-1])
|
| 75 |
+
x = shifted/repeated_maxes
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 79 |
+
downsample = None
|
| 80 |
+
if stride != 1 or self.inplanes != planes:
|
| 81 |
+
downsample = nn.Sequential(
|
| 82 |
+
nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
|
| 83 |
+
nn.BatchNorm2d(planes),
|
| 84 |
+
)
|
| 85 |
+
layers = []
|
| 86 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 87 |
+
self.inplanes = planes
|
| 88 |
+
for i in range(1, blocks):
|
| 89 |
+
layers.append(block(self.inplanes, planes))
|
| 90 |
+
return nn.Sequential(*layers)
|
| 91 |
+
|
| 92 |
+
def forward(self, x, return_mask=False):
|
| 93 |
+
# x = self.norm(x)
|
| 94 |
+
x = self.conv1(x)
|
| 95 |
+
x = self.maxpool(x)
|
| 96 |
+
x = self.layer0(x)
|
| 97 |
+
x = self.layer1(x)
|
| 98 |
+
x = self.layer2(x)
|
| 99 |
+
x = self.layer3(x)
|
| 100 |
+
x = self.avgpool(x)
|
| 101 |
+
x = x.view(x.size(0), -1)
|
| 102 |
+
x = self.dropout1(x)
|
| 103 |
+
x = self.fc(x)
|
| 104 |
+
# return x
|
| 105 |
+
if return_mask:
|
| 106 |
+
return x, self.mask, self.value
|
| 107 |
+
else:
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class ConvAutoencoder(nn.Module):
|
| 112 |
+
def __init__(self):
|
| 113 |
+
super(ConvAutoencoder, self).__init__()
|
| 114 |
+
|
| 115 |
+
# Encoder
|
| 116 |
+
self.encoder = nn.Sequential(
|
| 117 |
+
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # (16, 96, 128)
|
| 118 |
+
nn.ReLU(),
|
| 119 |
+
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (32, 48, 64)
|
| 120 |
+
nn.ReLU(),
|
| 121 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # (64, 24, 32)
|
| 122 |
+
nn.ReLU(),
|
| 123 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# (128, 12, 16)
|
| 124 |
+
nn.ReLU()
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Fully connected latent space
|
| 128 |
+
self.fc1 = nn.Linear(128 * 12 * 16, 8)
|
| 129 |
+
self.fc2 = nn.Linear(8, 128 * 12 * 16)
|
| 130 |
+
|
| 131 |
+
# Decoder
|
| 132 |
+
self.decoder = nn.Sequential(
|
| 133 |
+
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # (64, 24, 32)
|
| 134 |
+
nn.ReLU(),
|
| 135 |
+
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # (32, 48, 64)
|
| 136 |
+
nn.ReLU(),
|
| 137 |
+
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # (16, 96, 128)
|
| 138 |
+
nn.ReLU(),
|
| 139 |
+
nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # (3, 192, 256)
|
| 140 |
+
nn.Sigmoid() # Using Sigmoid for the final activation to get output in range [0, 1]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def forward(self, x):
|
| 144 |
+
# Encode
|
| 145 |
+
x = self.encoder(x)
|
| 146 |
+
|
| 147 |
+
# Flatten the encoded output
|
| 148 |
+
x = x.view(x.size(0), -1)
|
| 149 |
+
|
| 150 |
+
# Fully connected latent space
|
| 151 |
+
x = self.fc1(x)
|
| 152 |
+
x = self.fc2(x)
|
| 153 |
+
|
| 154 |
+
# Reshape the output to the shape suitable for the decoder
|
| 155 |
+
x = x.view(x.size(0), 128, 12, 16)
|
| 156 |
+
|
| 157 |
+
# Decode
|
| 158 |
+
x = self.decoder(x)
|
| 159 |
+
|
| 160 |
+
return x
|
models/.ipynb_checkpoints/resnet_model_mask-checkpoint.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class ResidualBlock(nn.Module):
|
| 6 |
+
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
|
| 7 |
+
super(ResidualBlock, self).__init__()
|
| 8 |
+
self.conv1 = nn.Sequential(
|
| 9 |
+
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
|
| 10 |
+
nn.BatchNorm2d(out_channels),
|
| 11 |
+
nn.ReLU())
|
| 12 |
+
self.conv2 = nn.Sequential(
|
| 13 |
+
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
|
| 14 |
+
nn.BatchNorm2d(out_channels))
|
| 15 |
+
self.downsample = downsample
|
| 16 |
+
self.relu = nn.ReLU()
|
| 17 |
+
self.out_channels = out_channels
|
| 18 |
+
self.dropout_percentage = 0.5
|
| 19 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
| 20 |
+
self.batchnorm_mod = nn.BatchNorm2d(out_channels)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
residual = x
|
| 24 |
+
out = self.conv1(x)
|
| 25 |
+
out = self.dropout1(out)
|
| 26 |
+
# out = self.batchnorm_mod(out)
|
| 27 |
+
out = self.conv2(out)
|
| 28 |
+
out = self.dropout1(out)
|
| 29 |
+
# out = self.batchnorm_mod(out)
|
| 30 |
+
if self.downsample:
|
| 31 |
+
residual = self.downsample(x)
|
| 32 |
+
out += residual
|
| 33 |
+
out = self.relu(out)
|
| 34 |
+
return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ResNet(nn.Module):
|
| 38 |
+
def __init__(self, inchan, block, layers, num_classes = 10):
|
| 39 |
+
super(ResNet, self).__init__()
|
| 40 |
+
self.inplanes = 64
|
| 41 |
+
self.eps = 1e-5
|
| 42 |
+
self.relu = nn.ReLU()
|
| 43 |
+
self.conv1 = nn.Sequential(
|
| 44 |
+
nn.Conv2d(inchan, 64, kernel_size = 7, stride = 2, padding = 3),
|
| 45 |
+
nn.BatchNorm2d(64),
|
| 46 |
+
nn.ReLU())
|
| 47 |
+
self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = 2, padding = 1)
|
| 48 |
+
self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
|
| 49 |
+
self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
|
| 50 |
+
self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
|
| 51 |
+
self.layer3 = self._make_layer(block, 512, layers[3], stride = 1)
|
| 52 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
| 53 |
+
self.fc = nn.Linear(39424, num_classes)
|
| 54 |
+
self.dropout_percentage = 0.3
|
| 55 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
| 56 |
+
|
| 57 |
+
# Encoder
|
| 58 |
+
self.encoder = nn.Sequential(
|
| 59 |
+
nn.Conv2d(24, 32, kernel_size = 3, stride =1, padding = 1),
|
| 60 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
| 61 |
+
nn.Conv2d(32, 64, kernel_size = 3, stride =1, padding = 1),
|
| 62 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
| 63 |
+
nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
|
| 64 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
| 65 |
+
nn.Conv2d(32, 24, kernel_size = 3, stride = 1, padding = 1),
|
| 66 |
+
nn.Sigmoid()
|
| 67 |
+
)
|
| 68 |
+
params = sum(p.numel() for p in self.encoder.parameters())
|
| 69 |
+
print("num params encoder ",params)
|
| 70 |
+
|
| 71 |
+
def norm(self, x):
|
| 72 |
+
shifted = x-x.min()
|
| 73 |
+
maxes = torch.amax(abs(shifted), dim=(-2, -1))
|
| 74 |
+
repeated_maxes = maxes.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[-2],x.shape[-1])
|
| 75 |
+
x = shifted/repeated_maxes
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 79 |
+
downsample = None
|
| 80 |
+
if stride != 1 or self.inplanes != planes:
|
| 81 |
+
downsample = nn.Sequential(
|
| 82 |
+
nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
|
| 83 |
+
nn.BatchNorm2d(planes),
|
| 84 |
+
)
|
| 85 |
+
layers = []
|
| 86 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 87 |
+
self.inplanes = planes
|
| 88 |
+
for i in range(1, blocks):
|
| 89 |
+
layers.append(block(self.inplanes, planes))
|
| 90 |
+
return nn.Sequential(*layers)
|
| 91 |
+
|
| 92 |
+
def forward(self, x, return_mask=False):
|
| 93 |
+
# # m = self.encoder(x).unsqueeze(-1).repeat(1, 1, 1, x.shape[-1])
|
| 94 |
+
m = self.encoder(x)
|
| 95 |
+
self.mask = m
|
| 96 |
+
self.value = x
|
| 97 |
+
# # m = nn.Sigmoid()(self.encoder(x))
|
| 98 |
+
x = x * m
|
| 99 |
+
# x = self.norm(x)
|
| 100 |
+
x = self.conv1(x)
|
| 101 |
+
x = self.maxpool(x)
|
| 102 |
+
x = self.layer0(x)
|
| 103 |
+
x = self.layer1(x)
|
| 104 |
+
x = self.layer2(x)
|
| 105 |
+
x = self.layer3(x)
|
| 106 |
+
x = self.avgpool(x)
|
| 107 |
+
x = x.view(x.size(0), -1)
|
| 108 |
+
x = self.dropout1(x)
|
| 109 |
+
x = self.fc(x)
|
| 110 |
+
return x
|
| 111 |
+
# if return_mask:
|
| 112 |
+
# return x, self.mask, self.value
|
| 113 |
+
# else:
|
| 114 |
+
# return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ConvAutoencoder(nn.Module):
|
| 118 |
+
def __init__(self):
|
| 119 |
+
super(ConvAutoencoder, self).__init__()
|
| 120 |
+
|
| 121 |
+
# Encoder
|
| 122 |
+
self.encoder = nn.Sequential(
|
| 123 |
+
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # (16, 96, 128)
|
| 124 |
+
nn.ReLU(),
|
| 125 |
+
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (32, 48, 64)
|
| 126 |
+
nn.ReLU(),
|
| 127 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # (64, 24, 32)
|
| 128 |
+
nn.ReLU(),
|
| 129 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# (128, 12, 16)
|
| 130 |
+
nn.ReLU()
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Fully connected latent space
|
| 134 |
+
self.fc1 = nn.Linear(128 * 12 * 16, 8)
|
| 135 |
+
self.fc2 = nn.Linear(8, 128 * 12 * 16)
|
| 136 |
+
|
| 137 |
+
# Decoder
|
| 138 |
+
self.decoder = nn.Sequential(
|
| 139 |
+
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # (64, 24, 32)
|
| 140 |
+
nn.ReLU(),
|
| 141 |
+
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # (32, 48, 64)
|
| 142 |
+
nn.ReLU(),
|
| 143 |
+
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # (16, 96, 128)
|
| 144 |
+
nn.ReLU(),
|
| 145 |
+
nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # (3, 192, 256)
|
| 146 |
+
nn.Sigmoid() # Using Sigmoid for the final activation to get output in range [0, 1]
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def forward(self, x):
|
| 150 |
+
# Encode
|
| 151 |
+
x = self.encoder(x)
|
| 152 |
+
|
| 153 |
+
# Flatten the encoded output
|
| 154 |
+
x = x.view(x.size(0), -1)
|
| 155 |
+
|
| 156 |
+
# Fully connected latent space
|
| 157 |
+
x = self.fc1(x)
|
| 158 |
+
x = self.fc2(x)
|
| 159 |
+
|
| 160 |
+
# Reshape the output to the shape suitable for the decoder
|
| 161 |
+
x = x.view(x.size(0), 128, 12, 16)
|
| 162 |
+
|
| 163 |
+
# Decode
|
| 164 |
+
x = self.decoder(x)
|
| 165 |
+
|
| 166 |
+
return x
|
models/.ipynb_checkpoints/train-checkpoint.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from resnet_model import ResidualBlock, ResNet
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import tqdm
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 12 |
+
import pickle
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
ind = int(sys.argv[1])
|
| 16 |
+
seeds = [1,42,7109,2002,32]
|
| 17 |
+
seed = seeds[ind]
|
| 18 |
+
print("using seed: ",seed)
|
| 19 |
+
torch.manual_seed(seed)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
num_gpus = torch.cuda.device_count()
|
| 24 |
+
print(num_gpus)
|
| 25 |
+
|
| 26 |
+
# Create custom dataset instance
|
| 27 |
+
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
|
| 28 |
+
dataset = CustomDataset(data_dir, transform=transform)
|
| 29 |
+
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
|
| 30 |
+
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
num_classes = 2
|
| 34 |
+
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
| 35 |
+
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
|
| 36 |
+
|
| 37 |
+
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
|
| 38 |
+
model = nn.DataParallel(model)
|
| 39 |
+
model = model.to(device)
|
| 40 |
+
params = sum(p.numel() for p in model.parameters())
|
| 41 |
+
print("num params ",params)
|
| 42 |
+
torch.save(model.state_dict(), 'models/test.pt')
|
| 43 |
+
model.load_state_dict(torch.load('models/test.pt'))
|
| 44 |
+
|
| 45 |
+
preproc_model = preproc()
|
| 46 |
+
Convert_ONNX(model.module,'models/test.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
|
| 47 |
+
Convert_ONNX(preproc_model,'models/preproc.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
|
| 48 |
+
|
| 49 |
+
# Define optimizer and loss function
|
| 50 |
+
|
| 51 |
+
criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
|
| 52 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
| 53 |
+
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
from tqdm import tqdm
|
| 57 |
+
# Training loop
|
| 58 |
+
epochs = 10000
|
| 59 |
+
for epoch in range(epochs):
|
| 60 |
+
running_loss = 0.0
|
| 61 |
+
correct_train = 0
|
| 62 |
+
total_train = 0
|
| 63 |
+
with tqdm(trainloader, unit="batch") as tepoch:
|
| 64 |
+
model.train()
|
| 65 |
+
for i, (images, labels) in enumerate(tepoch):
|
| 66 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
| 67 |
+
optimizer.zero_grad()
|
| 68 |
+
outputs = model(inputs, return_mask=False).to(device)
|
| 69 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
|
| 70 |
+
loss = criterion(outputs, new_label)
|
| 71 |
+
loss.backward()
|
| 72 |
+
optimizer.step()
|
| 73 |
+
running_loss += loss.item()
|
| 74 |
+
# Calculate training accuracy
|
| 75 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 76 |
+
total_train += labels.size(0)
|
| 77 |
+
correct_train += (predicted == labels).sum().item()
|
| 78 |
+
val_loss = 0.0
|
| 79 |
+
correct_valid = 0
|
| 80 |
+
total = 0
|
| 81 |
+
model.eval()
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
for images, labels in validloader:
|
| 84 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
| 85 |
+
optimizer.zero_grad()
|
| 86 |
+
outputs = model(inputs, return_mask=False)
|
| 87 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
|
| 88 |
+
loss = criterion(outputs, new_label)
|
| 89 |
+
val_loss += loss.item()
|
| 90 |
+
_, predicted = torch.max(outputs, 1)
|
| 91 |
+
total += labels.size(0)
|
| 92 |
+
correct_valid += (predicted == labels).sum().item()
|
| 93 |
+
scheduler.step(val_loss)
|
| 94 |
+
# Calculate training accuracy after each epoch
|
| 95 |
+
train_accuracy = 100 * correct_train / total_train
|
| 96 |
+
val_accuracy = correct_valid / total * 100.0
|
| 97 |
+
torch.save(model.state_dict(), 'models/model-'+str(epoch)+'-'+str(val_accuracy)+'.pt')
|
| 98 |
+
Convert_ONNX(model.module,'models/model-'+str(epoch)+'-'+str(val_accuracy)+'.onnx', input_data_mock=inputs)
|
| 99 |
+
|
| 100 |
+
print("===========================")
|
| 101 |
+
print('accuracy: ', epoch, train_accuracy, val_accuracy)
|
| 102 |
+
print('learning rate: ', scheduler.get_last_lr())
|
| 103 |
+
print("===========================")
|
| 104 |
+
if scheduler.get_last_lr()[0] <1e-6:
|
| 105 |
+
break
|
models/.ipynb_checkpoints/train-mask-8-checkpoint.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from resnet_model_mask import ResidualBlock, ResNet
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import tqdm
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 12 |
+
import pickle
|
| 13 |
+
import sys
|
| 14 |
+
# [1,42,7109,2002,32]
|
| 15 |
+
ind = int(sys.argv[1])
|
| 16 |
+
seeds = [1,42,7109,2002,32]
|
| 17 |
+
seed = seeds[ind]
|
| 18 |
+
torch.manual_seed(seed)
|
| 19 |
+
|
| 20 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
num_gpus = torch.cuda.device_count()
|
| 22 |
+
print(num_gpus)
|
| 23 |
+
|
| 24 |
+
# Create custom dataset instance
|
| 25 |
+
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
|
| 26 |
+
dataset = CustomDataset(data_dir, bit8 = True, transform=transform)
|
| 27 |
+
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
|
| 28 |
+
valid_dataset = CustomDataset(valid_data_dir, bit8 = True, transform=transform)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
num_classes = 2
|
| 32 |
+
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
| 33 |
+
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
|
| 34 |
+
|
| 35 |
+
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
|
| 36 |
+
model = nn.DataParallel(model)
|
| 37 |
+
model = model.to(device)
|
| 38 |
+
params = sum(p.numel() for p in model.parameters())
|
| 39 |
+
print("num params ",params)
|
| 40 |
+
torch.save(model.state_dict(), f'models_8/test_{seed}.pt')
|
| 41 |
+
model.load_state_dict(torch.load(f'models_8/test_{seed}.pt'))
|
| 42 |
+
|
| 43 |
+
preproc_model = preproc()
|
| 44 |
+
Convert_ONNX(model.module,f'models_8/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
|
| 45 |
+
Convert_ONNX(preproc_model,f'models_8/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
|
| 46 |
+
|
| 47 |
+
# Define optimizer and loss function
|
| 48 |
+
|
| 49 |
+
criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
|
| 50 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
| 51 |
+
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
from tqdm import tqdm
|
| 55 |
+
# Training loop
|
| 56 |
+
epochs = 10000
|
| 57 |
+
for epoch in range(epochs):
|
| 58 |
+
running_loss = 0.0
|
| 59 |
+
correct_train = 0
|
| 60 |
+
total_train = 0
|
| 61 |
+
with tqdm(trainloader, unit="batch") as tepoch:
|
| 62 |
+
model.train()
|
| 63 |
+
for i, (images, labels) in enumerate(tepoch):
|
| 64 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
| 65 |
+
optimizer.zero_grad()
|
| 66 |
+
outputs = model(inputs, return_mask=False).to(device)
|
| 67 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
|
| 68 |
+
loss = criterion(outputs, new_label)
|
| 69 |
+
loss.backward()
|
| 70 |
+
optimizer.step()
|
| 71 |
+
running_loss += loss.item()
|
| 72 |
+
# Calculate training accuracy
|
| 73 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 74 |
+
total_train += labels.size(0)
|
| 75 |
+
correct_train += (predicted == labels).sum().item()
|
| 76 |
+
val_loss = 0.0
|
| 77 |
+
correct_valid = 0
|
| 78 |
+
total = 0
|
| 79 |
+
model.eval()
|
| 80 |
+
with torch.no_grad():
|
| 81 |
+
for images, labels in validloader:
|
| 82 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
| 83 |
+
optimizer.zero_grad()
|
| 84 |
+
outputs = model(inputs, return_mask=False)
|
| 85 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
|
| 86 |
+
loss = criterion(outputs, new_label)
|
| 87 |
+
val_loss += loss.item()
|
| 88 |
+
_, predicted = torch.max(outputs, 1)
|
| 89 |
+
total += labels.size(0)
|
| 90 |
+
correct_valid += (predicted == labels).sum().item()
|
| 91 |
+
scheduler.step(val_loss)
|
| 92 |
+
# Calculate training accuracy after each epoch
|
| 93 |
+
train_accuracy = 100 * correct_train / total_train
|
| 94 |
+
val_accuracy = correct_valid / total * 100.0
|
| 95 |
+
torch.save(model.state_dict(), 'models_8/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.pt')
|
| 96 |
+
Convert_ONNX(model.module,'models_8/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.onnx', input_data_mock=inputs)
|
| 97 |
+
|
| 98 |
+
print("===========================")
|
| 99 |
+
print('accuracy: ', epoch, train_accuracy, val_accuracy)
|
| 100 |
+
print('learning rate: ', scheduler.get_last_lr())
|
| 101 |
+
print("===========================")
|
| 102 |
+
if scheduler.get_last_lr()[0] <1e-6:
|
| 103 |
+
break
|
models/.ipynb_checkpoints/train-mask-checkpoint.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from resnet_model_mask import ResidualBlock, ResNet
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import tqdm
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 12 |
+
import pickle
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
ind = int(sys.argv[1])
|
| 16 |
+
seeds = [1,42,7109,2002,32]
|
| 17 |
+
seed = seeds[ind]
|
| 18 |
+
print("using seed: ",seed)
|
| 19 |
+
torch.manual_seed(seed)
|
| 20 |
+
|
| 21 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
num_gpus = torch.cuda.device_count()
|
| 23 |
+
print(num_gpus)
|
| 24 |
+
|
| 25 |
+
# Create custom dataset instance
|
| 26 |
+
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
|
| 27 |
+
dataset = CustomDataset(data_dir, transform=transform)
|
| 28 |
+
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
|
| 29 |
+
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
num_classes = 2
|
| 33 |
+
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
| 34 |
+
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
|
| 35 |
+
|
| 36 |
+
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
|
| 37 |
+
model = nn.DataParallel(model)
|
| 38 |
+
model = model.to(device)
|
| 39 |
+
params = sum(p.numel() for p in model.parameters())
|
| 40 |
+
print("num params ",params)
|
| 41 |
+
torch.save(model.state_dict(), f'models_mask/test_{seed}.pt')
|
| 42 |
+
model.load_state_dict(torch.load(f'models_mask/test_{seed}.pt'))
|
| 43 |
+
|
| 44 |
+
preproc_model = preproc()
|
| 45 |
+
Convert_ONNX(model.module,f'models_mask/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
|
| 46 |
+
Convert_ONNX(preproc_model,f'models_mask/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
|
| 47 |
+
|
| 48 |
+
# Define optimizer and loss function
|
| 49 |
+
|
| 50 |
+
criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
|
| 51 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
| 52 |
+
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
from tqdm import tqdm
|
| 56 |
+
# Training loop
|
| 57 |
+
epochs = 10000
|
| 58 |
+
for epoch in range(epochs):
|
| 59 |
+
running_loss = 0.0
|
| 60 |
+
correct_train = 0
|
| 61 |
+
total_train = 0
|
| 62 |
+
with tqdm(trainloader, unit="batch") as tepoch:
|
| 63 |
+
model.train()
|
| 64 |
+
for i, (images, labels) in enumerate(tepoch):
|
| 65 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
| 66 |
+
optimizer.zero_grad()
|
| 67 |
+
outputs = model(inputs, return_mask=False).to(device)
|
| 68 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
|
| 69 |
+
loss = criterion(outputs, new_label)
|
| 70 |
+
loss.backward()
|
| 71 |
+
optimizer.step()
|
| 72 |
+
running_loss += loss.item()
|
| 73 |
+
# Calculate training accuracy
|
| 74 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 75 |
+
total_train += labels.size(0)
|
| 76 |
+
correct_train += (predicted == labels).sum().item()
|
| 77 |
+
val_loss = 0.0
|
| 78 |
+
correct_valid = 0
|
| 79 |
+
total = 0
|
| 80 |
+
model.eval()
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
for images, labels in validloader:
|
| 83 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
| 84 |
+
optimizer.zero_grad()
|
| 85 |
+
outputs = model(inputs, return_mask=False)
|
| 86 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
|
| 87 |
+
loss = criterion(outputs, new_label)
|
| 88 |
+
val_loss += loss.item()
|
| 89 |
+
_, predicted = torch.max(outputs, 1)
|
| 90 |
+
total += labels.size(0)
|
| 91 |
+
correct_valid += (predicted == labels).sum().item()
|
| 92 |
+
scheduler.step(val_loss)
|
| 93 |
+
# Calculate training accuracy after each epoch
|
| 94 |
+
train_accuracy = 100 * correct_train / total_train
|
| 95 |
+
val_accuracy = correct_valid / total * 100.0
|
| 96 |
+
torch.save(model.state_dict(), 'models_mask/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.pt')
|
| 97 |
+
Convert_ONNX(model.module,'models_mask/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.onnx', input_data_mock=inputs)
|
| 98 |
+
|
| 99 |
+
print("===========================")
|
| 100 |
+
print('accuracy: ', epoch, train_accuracy, val_accuracy)
|
| 101 |
+
print('learning rate: ', scheduler.get_last_lr())
|
| 102 |
+
print("===========================")
|
| 103 |
+
if scheduler.get_last_lr()[0] <1e-6:
|
| 104 |
+
break
|
models/.ipynb_checkpoints/utils-checkpoint.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pickle
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from blimpy import Waterfall
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
from sigpyproc.readers import FilReader
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_pickled_data(file_path):
|
| 15 |
+
with open(file_path, 'rb') as f:
|
| 16 |
+
data = pickle.load(f)
|
| 17 |
+
return data
|
| 18 |
+
|
| 19 |
+
# Custom dataset class
|
| 20 |
+
class CustomDataset(Dataset):
|
| 21 |
+
def __init__(self, data_dir, bit8=False, transform=None):
|
| 22 |
+
self.data_dir = data_dir
|
| 23 |
+
self.transform = transform
|
| 24 |
+
self.images = []
|
| 25 |
+
self.labels = []
|
| 26 |
+
self.classes = os.listdir(data_dir)
|
| 27 |
+
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
|
| 28 |
+
self.bit8 = bit8
|
| 29 |
+
# Load images and labels
|
| 30 |
+
for cls in self.classes:
|
| 31 |
+
class_dir = os.path.join(data_dir, cls)
|
| 32 |
+
for image_name in os.listdir(class_dir):
|
| 33 |
+
image_path = os.path.join(class_dir, image_name)
|
| 34 |
+
self.images.append(image_path)
|
| 35 |
+
self.labels.append(self.class_to_idx[cls])
|
| 36 |
+
|
| 37 |
+
def __len__(self):
|
| 38 |
+
return len(self.images)
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, idx):
|
| 41 |
+
image_path = self.images[idx]
|
| 42 |
+
label = self.labels[idx]
|
| 43 |
+
# Load image
|
| 44 |
+
image = load_pickled_data(image_path)
|
| 45 |
+
if self.transform is not None:
|
| 46 |
+
if self.bit8 == True:
|
| 47 |
+
new_image = self.transform(torch.from_numpy(image['8_data']).type(torch.float32))
|
| 48 |
+
else:
|
| 49 |
+
new_image = self.transform(torch.from_numpy(image['data']))
|
| 50 |
+
# new_image = self.transform(image['data'])
|
| 51 |
+
return new_image, label
|
| 52 |
+
|
| 53 |
+
# Custom dataset class
|
| 54 |
+
class CustomDataset_Masked(Dataset):
|
| 55 |
+
def __init__(self, data_dir, transform=None):
|
| 56 |
+
self.data_dir = data_dir
|
| 57 |
+
self.transform = transform
|
| 58 |
+
self.images = []
|
| 59 |
+
self.labels = []
|
| 60 |
+
self.classes = os.listdir(data_dir)
|
| 61 |
+
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
|
| 62 |
+
|
| 63 |
+
# Load images and labels
|
| 64 |
+
for cls in self.classes:
|
| 65 |
+
class_dir = os.path.join(data_dir, cls)
|
| 66 |
+
for image_name in os.listdir(class_dir):
|
| 67 |
+
image_path = os.path.join(class_dir, image_name)
|
| 68 |
+
self.images.append(image_path)
|
| 69 |
+
self.labels.append(self.class_to_idx[cls])
|
| 70 |
+
|
| 71 |
+
def __len__(self):
|
| 72 |
+
return len(self.images)
|
| 73 |
+
|
| 74 |
+
def __getitem__(self, idx):
|
| 75 |
+
image_path = self.images[idx]
|
| 76 |
+
|
| 77 |
+
label = self.labels[idx]
|
| 78 |
+
# Load image
|
| 79 |
+
image = load_pickled_data(image_path)
|
| 80 |
+
if self.transform is not None:
|
| 81 |
+
if image['burst'].max() ==0:
|
| 82 |
+
new_burst = torch.from_numpy(image['burst'])
|
| 83 |
+
else:
|
| 84 |
+
new_burst = torch.from_numpy(image['burst']/image['burst'].max())
|
| 85 |
+
ind = new_burst > 0.1
|
| 86 |
+
ind_not = new_burst <= 0.1
|
| 87 |
+
new_burst[ind] = 1
|
| 88 |
+
new_burst[ind_not] = 0
|
| 89 |
+
new_image = self.transform(torch.from_numpy(image['data'].data))
|
| 90 |
+
new_burst_arr = torch.zeros_like(new_image)
|
| 91 |
+
new_burst_arr[ 0, :,:] = new_burst
|
| 92 |
+
new_burst_arr[ 1, :,:] = new_burst
|
| 93 |
+
new_burst_arr[ 2, :,:] = new_burst
|
| 94 |
+
return new_image, label, new_burst_arr
|
| 95 |
+
|
| 96 |
+
# Custom dataset class
|
| 97 |
+
class TestingDataset(Dataset):
|
| 98 |
+
def __init__(self, data_dir, bit8=False, transform=None):
|
| 99 |
+
self.data_dir = data_dir
|
| 100 |
+
self.transform = transform
|
| 101 |
+
self.images = []
|
| 102 |
+
self.labels = []
|
| 103 |
+
self.classes = os.listdir(data_dir)
|
| 104 |
+
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
|
| 105 |
+
self.bit8 = bit8
|
| 106 |
+
# Load images and labels
|
| 107 |
+
for cls in self.classes:
|
| 108 |
+
class_dir = os.path.join(data_dir, cls)
|
| 109 |
+
for image_name in os.listdir(class_dir):
|
| 110 |
+
image_path = os.path.join(class_dir, image_name)
|
| 111 |
+
self.images.append(image_path)
|
| 112 |
+
self.labels.append(self.class_to_idx[cls])
|
| 113 |
+
|
| 114 |
+
def __len__(self):
|
| 115 |
+
return len(self.images)
|
| 116 |
+
|
| 117 |
+
def __getitem__(self, idx):
|
| 118 |
+
image_path = self.images[idx]
|
| 119 |
+
label = self.labels[idx]
|
| 120 |
+
# Load image
|
| 121 |
+
image = load_pickled_data(image_path)
|
| 122 |
+
params = image['params']
|
| 123 |
+
if self.transform is not None:
|
| 124 |
+
params = image['params']
|
| 125 |
+
if self.bit8 == True:
|
| 126 |
+
new_image = self.transform(torch.from_numpy(image['8_data']).type(torch.float32))
|
| 127 |
+
else:
|
| 128 |
+
new_image = self.transform(torch.from_numpy(image['data']))
|
| 129 |
+
params['labels'] = label
|
| 130 |
+
return new_image, (label, params['dm'], params['freq_ref'], params['snr'], params['boxcard'])
|
| 131 |
+
|
| 132 |
+
# Custom dataset class
|
| 133 |
+
class SearchDataset(Dataset):
|
| 134 |
+
def __init__(self, data_dir, transform=None, pickle_data=False):
|
| 135 |
+
self.window_size = 2048
|
| 136 |
+
|
| 137 |
+
if pickle_data:
|
| 138 |
+
with open(data_dir, 'rb') as f:
|
| 139 |
+
self.d = pickle.load(f)
|
| 140 |
+
self.header = self.d['header']
|
| 141 |
+
self.images = self.crop(self.d['data'][:,0,:], self.window_size)
|
| 142 |
+
else:
|
| 143 |
+
self.obs = Waterfall(data_dir, max_load = 50)
|
| 144 |
+
self.header = self.obs.header
|
| 145 |
+
self.images = self.crop(self.obs.data[:,0,:], self.window_size)
|
| 146 |
+
self.transform = transform
|
| 147 |
+
self.SEC_PER_DAY = 86400
|
| 148 |
+
|
| 149 |
+
def crop(self, data, window_size = 2048):
|
| 150 |
+
n_samp = data.shape[0]//window_size
|
| 151 |
+
new_data = np.zeros((n_samp, window_size, 192 ))
|
| 152 |
+
for i in range(n_samp):
|
| 153 |
+
new_data[i, :,:] = data[ i*window_size : (i+1)*window_size, :]
|
| 154 |
+
return new_data
|
| 155 |
+
|
| 156 |
+
def __len__(self):
|
| 157 |
+
return self.images.shape[0]
|
| 158 |
+
def __getitem__(self, idx):
|
| 159 |
+
data = self.images[idx, :, :].T
|
| 160 |
+
tindex = idx * self.window_size
|
| 161 |
+
time = self.header['tsamp'] * tindex / self.SEC_PER_DAY + self.header['tstart']
|
| 162 |
+
if self.transform is not None:
|
| 163 |
+
new_image = self.transform(data)
|
| 164 |
+
return new_image, idx
|
| 165 |
+
|
| 166 |
+
# Custom dataset class
|
| 167 |
+
class SearchDataset_Sigproc(Dataset):
|
| 168 |
+
def __init__(self, data_dir, transform=None):
|
| 169 |
+
self.window_size = 2048
|
| 170 |
+
fil = FilReader(data_dir)
|
| 171 |
+
self.header = fil.header
|
| 172 |
+
# print("check shape ",fil.read_block(0, fil.header.nsamples).shape)
|
| 173 |
+
read_data = fil.read_block(0, fil.header.nsamples)[:,1024:-1024]
|
| 174 |
+
read_data = np.swapaxes(read_data, 0,-1)
|
| 175 |
+
self.images = self.crop(read_data, self.window_size)
|
| 176 |
+
self.transform = transform
|
| 177 |
+
self.SEC_PER_DAY = 86400
|
| 178 |
+
|
| 179 |
+
def crop(self, data, window_size = 2048):
|
| 180 |
+
n_samp = data.shape[0]//window_size
|
| 181 |
+
new_data = np.zeros((n_samp, window_size, 192 ))
|
| 182 |
+
for i in range(n_samp):
|
| 183 |
+
new_data[i, :,:] = data[ i*window_size : (i+1)*window_size, :]
|
| 184 |
+
return new_data
|
| 185 |
+
|
| 186 |
+
def __len__(self):
|
| 187 |
+
return self.images.shape[0]
|
| 188 |
+
|
| 189 |
+
def __getitem__(self, idx):
|
| 190 |
+
data = self.images[idx, :, :].T
|
| 191 |
+
tindex = idx * self.window_size
|
| 192 |
+
time = self.header.tsamp * tindex / self.SEC_PER_DAY + self.header.tstart
|
| 193 |
+
if self.transform is not None:
|
| 194 |
+
new_image = self.transform(torch.from_numpy(data))
|
| 195 |
+
return new_image, idx
|
| 196 |
+
|
| 197 |
+
# def renorm(data):
|
| 198 |
+
# shifted = data - data.min()
|
| 199 |
+
# shifted = shifted/shifted.max()
|
| 200 |
+
# return shifted
|
| 201 |
+
|
| 202 |
+
def renorm(data):
|
| 203 |
+
mean = torch.mean(data)
|
| 204 |
+
std = torch.std(data)
|
| 205 |
+
# Standardize the data
|
| 206 |
+
standardized_data = (data - mean) / std
|
| 207 |
+
return standardized_data
|
| 208 |
+
|
| 209 |
+
def transform(data):
|
| 210 |
+
copy_data = data.detach().clone()
|
| 211 |
+
rms = torch.std(data)
|
| 212 |
+
mean = torch.mean(data)
|
| 213 |
+
masks_rms = [-1, 5]
|
| 214 |
+
new_data = torch.zeros((len(masks_rms)+1, data.shape[0], data.shape[1]))
|
| 215 |
+
new_data[0,:,:] = renorm(torch.log10(copy_data+1e-10))
|
| 216 |
+
for i in range(1, len(masks_rms)+1):
|
| 217 |
+
scale = masks_rms[i-1]
|
| 218 |
+
copy_data = data.detach().clone() #deepcopy(data)
|
| 219 |
+
if scale < 0:
|
| 220 |
+
ind = copy_data < abs(scale) * rms + mean
|
| 221 |
+
copy_data[ind] = 0
|
| 222 |
+
else:
|
| 223 |
+
ind = copy_data > (scale) * rms + mean
|
| 224 |
+
copy_data[ind] = 0
|
| 225 |
+
new_data[i,:,:] = renorm(torch.log10(copy_data+1e-10))
|
| 226 |
+
new_data = new_data.type(torch.float32)
|
| 227 |
+
slices = torch.chunk(new_data, 8, dim=-1) # dim=1 is the height dimension
|
| 228 |
+
new_data = torch.stack(slices, dim=1) # New axis is inserted at dim=1
|
| 229 |
+
new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
|
| 230 |
+
return new_data
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def renorm_batched(data):
|
| 234 |
+
mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)
|
| 235 |
+
std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)
|
| 236 |
+
standardized_data = (data - mean) / std
|
| 237 |
+
return standardized_data
|
| 238 |
+
|
| 239 |
+
def transform_batched(data):
|
| 240 |
+
copy_data = data.detach().clone()
|
| 241 |
+
rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std
|
| 242 |
+
mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean
|
| 243 |
+
masks_rms = [-1, 5]
|
| 244 |
+
|
| 245 |
+
# Prepare the new_data tensor
|
| 246 |
+
num_masks = len(masks_rms) + 1
|
| 247 |
+
new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)
|
| 248 |
+
|
| 249 |
+
# First layer: Apply renorm(log10(copy_data + epsilon))
|
| 250 |
+
new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))
|
| 251 |
+
for i, scale in enumerate(masks_rms, start=1):
|
| 252 |
+
copy_data = data.detach().clone()
|
| 253 |
+
|
| 254 |
+
# Apply masking based on the scale
|
| 255 |
+
if scale < 0:
|
| 256 |
+
ind = copy_data < abs(scale) * rms + mean
|
| 257 |
+
else:
|
| 258 |
+
ind = copy_data > scale * rms + mean
|
| 259 |
+
copy_data[ind] = 0
|
| 260 |
+
|
| 261 |
+
# Renormalize and log10 transform
|
| 262 |
+
new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))
|
| 263 |
+
|
| 264 |
+
# Convert to float32
|
| 265 |
+
new_data = new_data.type(torch.float32)
|
| 266 |
+
|
| 267 |
+
# Chunk along the last dimension and stack
|
| 268 |
+
slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing
|
| 269 |
+
new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1
|
| 270 |
+
new_data = torch.swapaxes(new_data, 0,1)
|
| 271 |
+
# Reshape into final format
|
| 272 |
+
new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions
|
| 273 |
+
return new_data
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class preproc(nn.Module):
|
| 278 |
+
def forward(self, x, flip=True):
|
| 279 |
+
if flip:
|
| 280 |
+
transform_batched(torch.flip(x, dims = (-2,)))
|
| 281 |
+
else:
|
| 282 |
+
transform_batched(x)
|
| 283 |
+
return template
|
| 284 |
+
|
| 285 |
+
# class preproc_debug(nn.Module):
|
| 286 |
+
# def forward(self, x):
|
| 287 |
+
# template = torch.zeros((32, 24, 192, 256))
|
| 288 |
+
# # for i in torch.arange(x.shape[0]): # Use a tensor-based range
|
| 289 |
+
# template[0,:,:,:] = transform_debug(torch.flip(x[0,:,:], dims = (0,)))
|
| 290 |
+
# template[1,:,:,:] = transform_debug(torch.flip(x[1,:,:], dims = (0,)))
|
| 291 |
+
# template[2,:,:,:] = transform_debug(torch.flip(x[2,:,:], dims = (0,)))
|
| 292 |
+
# template[3,:,:,:] = transform_debug(torch.flip(x[3,:,:], dims = (0,)))
|
| 293 |
+
# template[4,:,:,:] = transform_debug(torch.flip(x[4,:,:], dims = (0,)))
|
| 294 |
+
# template[5,:,:,:] = transform_debug(torch.flip(x[5,:,:], dims = (0,)))
|
| 295 |
+
# template[6,:,:,:] = transform_debug(torch.flip(x[6,:,:], dims = (0,)))
|
| 296 |
+
# template[7,:,:,:] = transform_debug(torch.flip(x[7,:,:], dims = (0,)))
|
| 297 |
+
# template[8,:,:,:] = transform_debug(torch.flip(x[8,:,:], dims = (0,)))
|
| 298 |
+
# template[9,:,:,:] = transform_debug(torch.flip(x[9,:,:], dims = (0,)))
|
| 299 |
+
# template[10,:,:,:] = transform_debug(torch.flip(x[10,:,:], dims = (0,)))
|
| 300 |
+
# template[11,:,:,:] = transform_debug(torch.flip(x[11,:,:], dims = (0,)))
|
| 301 |
+
# template[12,:,:,:] = transform_debug(torch.flip(x[12,:,:], dims = (0,)))
|
| 302 |
+
# template[13,:,:,:] = transform_debug(torch.flip(x[13,:,:], dims = (0,)))
|
| 303 |
+
# template[14,:,:,:] = transform_debug(torch.flip(x[14,:,:], dims = (0,)))
|
| 304 |
+
# template[15,:,:,:] = transform_debug(torch.flip(x[15,:,:], dims = (0,)))
|
| 305 |
+
# template[16,:,:,:] = transform_debug(torch.flip(x[16,:,:], dims = (0,)))
|
| 306 |
+
# template[17,:,:,:] = transform_debug(torch.flip(x[17,:,:], dims = (0,)))
|
| 307 |
+
# template[18,:,:,:] = transform_debug(torch.flip(x[18,:,:], dims = (0,)))
|
| 308 |
+
# template[19,:,:,:] = transform_debug(torch.flip(x[19,:,:], dims = (0,)))
|
| 309 |
+
# template[20,:,:,:] = transform_debug(torch.flip(x[20,:,:], dims = (0,)))
|
| 310 |
+
# template[21,:,:,:] = transform_debug(torch.flip(x[21,:,:], dims = (0,)))
|
| 311 |
+
# template[22,:,:,:] = transform_debug(torch.flip(x[22,:,:], dims = (0,)))
|
| 312 |
+
# template[23,:,:,:] = transform_debug(torch.flip(x[23,:,:], dims = (0,)))
|
| 313 |
+
# template[24,:,:,:] = transform_debug(torch.flip(x[24,:,:], dims = (0,)))
|
| 314 |
+
# template[25,:,:,:] = transform_debug(torch.flip(x[25,:,:], dims = (0,)))
|
| 315 |
+
# template[26,:,:,:] = transform_debug(torch.flip(x[26,:,:], dims = (0,)))
|
| 316 |
+
# template[27,:,:,:] = transform_debug(torch.flip(x[27,:,:], dims = (0,)))
|
| 317 |
+
# template[28,:,:,:] = transform_debug(torch.flip(x[28,:,:], dims = (0,)))
|
| 318 |
+
# template[29,:,:,:] = transform_debug(torch.flip(x[29,:,:], dims = (0,)))
|
| 319 |
+
# template[30,:,:,:] = transform_debug(torch.flip(x[30,:,:], dims = (0,)))
|
| 320 |
+
# template[31,:,:,:] = transform_debug(torch.flip(x[31,:,:], dims = (0,)))
|
| 321 |
+
# return template
|
| 322 |
+
|
| 323 |
+
# def transform_debug(data):
|
| 324 |
+
# copy_data = data.detach().clone()
|
| 325 |
+
# rms = torch.std(data)
|
| 326 |
+
# mean = torch.mean(data)
|
| 327 |
+
# masks_rms = [-1, 5]
|
| 328 |
+
# new_data = torch.zeros((len(masks_rms)+1, data.shape[0], data.shape[1]))
|
| 329 |
+
# new_data[0,:,:] = renorm(torch.log10(copy_data+1e-10))
|
| 330 |
+
# for i in range(1, len(masks_rms)+1):
|
| 331 |
+
# scale = masks_rms[i-1]
|
| 332 |
+
# copy_data = data.detach().clone()
|
| 333 |
+
# if scale < 0:
|
| 334 |
+
# ind = copy_data < abs(scale) * rms + mean
|
| 335 |
+
# copy_data[ind] = 0
|
| 336 |
+
# else:
|
| 337 |
+
# ind = copy_data > (scale) * rms + mean
|
| 338 |
+
# copy_data[ind] = 0
|
| 339 |
+
# new_data[i,:,:] = renorm(torch.log10(copy_data+1e-10))
|
| 340 |
+
# new_data = new_data.type(torch.float32)
|
| 341 |
+
# slices = torch.chunk(new_data, 8, dim=-1) # dim=1 is the height dimension
|
| 342 |
+
# new_data = torch.stack(slices, dim=1) # New axis is inserted at dim=1
|
| 343 |
+
# new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
|
| 344 |
+
# return new_data
|
| 345 |
+
|
| 346 |
+
def renorm_batched(data):
|
| 347 |
+
mins = torch.amin(data, (-2, -1))
|
| 348 |
+
mins = mins.unsqueeze(1).unsqueeze(2)
|
| 349 |
+
mins = mins.expand(data.shape[0], 192, 2048)
|
| 350 |
+
shifted = data - mins
|
| 351 |
+
maxs = torch.amax(shifted, (-2, -1))
|
| 352 |
+
maxs = maxs.unsqueeze(1).unsqueeze(2)
|
| 353 |
+
maxs = maxs.expand(data.shape[0], 192, 2048)
|
| 354 |
+
shifted = shifted/maxs
|
| 355 |
+
return shifted
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def transform_mask(data):
|
| 359 |
+
copy_data = deepcopy(data)
|
| 360 |
+
shift = copy_data - copy_data.min()
|
| 361 |
+
normalized_data = shift / shift.max()
|
| 362 |
+
new_data = np.zeros((3, data.shape[0], data.shape[1]))
|
| 363 |
+
for i in range(3):
|
| 364 |
+
new_data[i,:,:] = normalized_data
|
| 365 |
+
new_data = new_data.astype(np.float32)
|
| 366 |
+
return new_data
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
#Function to Convert to ONNX
|
| 370 |
+
def Convert_ONNX(model, saveloc, input_data_mock):
|
| 371 |
+
print("Saving to ONNX")
|
| 372 |
+
# set the model to inference mode
|
| 373 |
+
model.eval()
|
| 374 |
+
|
| 375 |
+
# Let's create a dummy input tensor
|
| 376 |
+
dummy_input = torch.autograd.Variable(input_data_mock)
|
| 377 |
+
|
| 378 |
+
# Export the model
|
| 379 |
+
torch.onnx.export(model, # model being run
|
| 380 |
+
dummy_input, # model input (or a tuple for multiple inputs)
|
| 381 |
+
saveloc, # where to save the model
|
| 382 |
+
input_names = ['modelInput'], # the model's input names
|
| 383 |
+
output_names = ['modelOutput'], # the model's output names
|
| 384 |
+
dynamic_axes={'modelInput' : {0 : 'batch_size'}, # variable length axes
|
| 385 |
+
'modelOutput' : {0 : 'batch_size'}} )
|
| 386 |
+
print(" ")
|
| 387 |
+
print('Model has been converted to ONNX')
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
|
models/.ipynb_checkpoints/utils_batched_preproc-checkpoint.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pickle
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from blimpy import Waterfall
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
from sigpyproc.readers import FilReader
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def renorm_batched(data):
|
| 15 |
+
mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)
|
| 16 |
+
std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)
|
| 17 |
+
standardized_data = (data - mean) / std
|
| 18 |
+
return standardized_data
|
| 19 |
+
|
| 20 |
+
def transform_batched(data):
|
| 21 |
+
copy_data = data.detach().clone()
|
| 22 |
+
rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std
|
| 23 |
+
mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean
|
| 24 |
+
masks_rms = [-1, 5]
|
| 25 |
+
|
| 26 |
+
# Prepare the new_data tensor
|
| 27 |
+
num_masks = len(masks_rms) + 1
|
| 28 |
+
new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)
|
| 29 |
+
|
| 30 |
+
# First layer: Apply renorm(log10(copy_data + epsilon))
|
| 31 |
+
new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))
|
| 32 |
+
for i, scale in enumerate(masks_rms, start=1):
|
| 33 |
+
copy_data = data.detach().clone()
|
| 34 |
+
|
| 35 |
+
# Apply masking based on the scale
|
| 36 |
+
if scale < 0:
|
| 37 |
+
ind = copy_data < abs(scale) * rms + mean
|
| 38 |
+
else:
|
| 39 |
+
ind = copy_data > scale * rms + mean
|
| 40 |
+
copy_data[ind] = 0
|
| 41 |
+
|
| 42 |
+
# Renormalize and log10 transform
|
| 43 |
+
new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))
|
| 44 |
+
|
| 45 |
+
# Convert to float32
|
| 46 |
+
new_data = new_data.type(torch.float32)
|
| 47 |
+
|
| 48 |
+
# Chunk along the last dimension and stack
|
| 49 |
+
slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing
|
| 50 |
+
new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1
|
| 51 |
+
new_data = torch.swapaxes(new_data, 0,1)
|
| 52 |
+
# Reshape into final format
|
| 53 |
+
new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions
|
| 54 |
+
return new_data
|
| 55 |
+
|
| 56 |
+
class preproc_flip(nn.Module):
|
| 57 |
+
def forward(self, x, flip=True):
|
| 58 |
+
template = transform_batched(torch.flip(x, dims = (-2,)))
|
| 59 |
+
return template
|
| 60 |
+
|
| 61 |
+
class preproc(nn.Module):
|
| 62 |
+
def forward(self, x, flip=True):
|
| 63 |
+
template = transform_batched(x)
|
| 64 |
+
return template
|
| 65 |
+
|
models/HITS-FEB-10.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1ce8b602cef03e22c666cdea792741411f623fb5eb0a254ef1ffd9a32864d754
|
| 3 |
+
size 270858960
|
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739230556_9.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7cc8774df726c0db3358653dc48894f70f2964cac7604dc5319b44f8f6340b71
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739230556_9.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739231399_1.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f66e781e7776a11067d0b103b36e9f42b7130d5297ab8936a814af659e125524
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739231399_1.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739231802_11.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9edde89b6a0526420dcff5e35516979a20f379fac7a2c98d59b9dae897bf4426
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739231802_11.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739234628_13.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:550f498a863202da1dad26237f7daf0a4f347e7fecde3650e9abee7611994078
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739234628_13.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739234628_14.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b49681cd9e55c1b3f81be0d2e9c04aacece284581b5b534cf8531ef5464b084
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739234628_14.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739235333_29.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04f284312e3c2d80e49d617bec875f1b6664bbe21e1418ba9fee8f0cafe70e24
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739235333_29.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739235841_12.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5b8008a6d13ef7b36fe03fd9dab1e58a6bbb770b433ffb77a3eb11a5033a09f
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739235841_12.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_50233055_1739232802_29.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2d0a49b0bf3825cfbbdbd91c465764f82b326e172c378685c9de87a44f1296e
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_50233055_1739232802_29.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_52111435_1739229641_28.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:caa8876c4c3185ad998f1a61a2b4176dec62d5ad38e3ab585aa9ff730848f83f
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_52111435_1739229641_28.png
ADDED
|
Git LFS Details
|
models/HITS-FEB-10/hit_52550001_1739233595_4.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a69ae7f971552525b8beac6adee0baa0eb066158fa73e2748265e7a803f8c9f
|
| 3 |
+
size 1572992
|
models/HITS-FEB-10/hit_52550001_1739233595_4.png
ADDED
|
Git LFS Details
|