CMuSeNet Training / Validation code and Synthetic IQ samples generator
Browse files- CMuSeNet_BIGRED.ipynb +1277 -0
- CMuSeNet_Indoor_OTA.ipynb +1658 -0
- CMuSeNet_Synthetic.ipynb +1241 -0
- CMuSeNet_Synthetic_IQ_Generator/README.txt +26 -0
- CMuSeNet_Synthetic_IQ_Generator/datagen.m +28 -0
- CMuSeNet_Synthetic_IQ_Generator/datagenTransmitter.m +64 -0
- CMuSeNet_Synthetic_IQ_Generator/datagenWideband.m +147 -0
CMuSeNet_BIGRED.ipynb
ADDED
|
@@ -0,0 +1,1277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "b5007b71",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"### Initialization"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "3e6b1226",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"from pathlib import Path\n",
|
| 19 |
+
"import numpy as np\n",
|
| 20 |
+
"from scipy.signal import welch\n",
|
| 21 |
+
"import torch\n",
|
| 22 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 23 |
+
"from tqdm import tqdm\n",
|
| 24 |
+
"import math\n",
|
| 25 |
+
"import json\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"# Constants\n",
|
| 28 |
+
"START_INDEX = 10 # Skip first few samples\n",
|
| 29 |
+
"SIGNAL_LENGTH = 1024 * 16\n",
|
| 30 |
+
"SAMPLE_RATE = 20e6\n",
|
| 31 |
+
"MASK_SIZE = 1024 * 16 # Mask size for segmentation\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"# Functions for Signal Processing\n",
|
| 34 |
+
"def load_real_data(sample_path):\n",
|
| 35 |
+
" \"\"\"\n",
|
| 36 |
+
" Load raw signal data from a .dat file.\n",
|
| 37 |
+
" \"\"\"\n",
|
| 38 |
+
" with open(sample_path, \"rb\") as f:\n",
|
| 39 |
+
" signal = np.fromfile(f, dtype=np.complex64)\n",
|
| 40 |
+
" return signal\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"def load_data(signal_id):\n",
|
| 43 |
+
" \"\"\"\n",
|
| 44 |
+
" Load signal data and its corresponding metadata.\n",
|
| 45 |
+
" \"\"\"\n",
|
| 46 |
+
" signal = load_real_data(signal_id)\n",
|
| 47 |
+
" metadata_file = signal_id.with_suffix(\".json\")\n",
|
| 48 |
+
" if metadata_file.exists():\n",
|
| 49 |
+
" with open(metadata_file, \"r\") as f:\n",
|
| 50 |
+
" metadata = json.load(f)\n",
|
| 51 |
+
" else:\n",
|
| 52 |
+
" raise FileNotFoundError(f\"Metadata file {metadata_file} not found for signal {signal_id}\")\n",
|
| 53 |
+
" return signal[START_INDEX:], metadata, metadata_file\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"def apply_psd(signal, Fs, NFFT):\n",
|
| 56 |
+
" \"\"\"\n",
|
| 57 |
+
" Calculate the PSD and corresponding frequencies using Welch's method.\n",
|
| 58 |
+
" \"\"\"\n",
|
| 59 |
+
" freqs, psd = welch(signal, fs=Fs, nfft=NFFT, return_onesided=False)\n",
|
| 60 |
+
" psd = np.fft.fftshift(psd)\n",
|
| 61 |
+
" freqs = np.fft.fftshift(freqs)\n",
|
| 62 |
+
" return psd, freqs\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"def calculate_fft(signal):\n",
|
| 65 |
+
" \"\"\"\n",
|
| 66 |
+
" Calculate the FFT of the signal and return real and imaginary parts as separate channels.\n",
|
| 67 |
+
" \"\"\"\n",
|
| 68 |
+
" signal = signal[:SIGNAL_LENGTH]\n",
|
| 69 |
+
" signal = np.fft.fft(signal)\n",
|
| 70 |
+
" signal = np.fft.fftshift(signal)\n",
|
| 71 |
+
" signal /= np.max(np.abs(signal))\n",
|
| 72 |
+
" return signal"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"id": "440b802c",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"source": [
|
| 80 |
+
"### Data Loading"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": null,
|
| 86 |
+
"id": "31bc3770",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"# Dataset Class\n",
|
| 91 |
+
"class WidebandSignalDataset(Dataset):\n",
|
| 92 |
+
" def __init__(self, signal_ids, mask_size=1024 * 16):\n",
|
| 93 |
+
" \"\"\"\n",
|
| 94 |
+
" Initialize the dataset with signal IDs and the specified mask size.\n",
|
| 95 |
+
" \"\"\"\n",
|
| 96 |
+
" self.mask_size = mask_size\n",
|
| 97 |
+
" self.signal_ids = signal_ids\n",
|
| 98 |
+
" self.loaded_data = [self.process_signal(signal_id) for signal_id in tqdm(self.signal_ids)]\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" def __len__(self):\n",
|
| 101 |
+
" return len(self.signal_ids)\n",
|
| 102 |
+
"\n",
|
| 103 |
+
" def __getitem__(self, index):\n",
|
| 104 |
+
" return self.loaded_data[index]\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" def process_signal(self, signal_id):\n",
|
| 107 |
+
" signal, metadata, _ = load_data(signal_id)\n",
|
| 108 |
+
"\n",
|
| 109 |
+
" # Ensure signal length matches SIGNAL_LENGTH\n",
|
| 110 |
+
" if len(signal) < SIGNAL_LENGTH:\n",
|
| 111 |
+
" # Pad with zeros if the signal is shorter\n",
|
| 112 |
+
" signal = np.pad(signal, (0, SIGNAL_LENGTH - len(signal)), mode='constant')\n",
|
| 113 |
+
" elif len(signal) > SIGNAL_LENGTH:\n",
|
| 114 |
+
" # Truncate if the signal is longer\n",
|
| 115 |
+
" signal = signal[:SIGNAL_LENGTH]\n",
|
| 116 |
+
"\n",
|
| 117 |
+
" # Apply FFT\n",
|
| 118 |
+
" signal = np.fft.fft(signal)\n",
|
| 119 |
+
" signal = np.fft.fftshift(signal)\n",
|
| 120 |
+
" signal /= np.max(np.abs(signal)) # Normalize\n",
|
| 121 |
+
" complex_signal = torch.from_numpy(signal).type(torch.complex64).unsqueeze(0) # Add channel dimension\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" # Create mask with fixed size\n",
|
| 124 |
+
" masks = torch.zeros(self.mask_size, dtype=torch.float32)\n",
|
| 125 |
+
" scale_ratio = self.mask_size / SAMPLE_RATE\n",
|
| 126 |
+
" scaled_metadata = process_metadata(metadata)\n",
|
| 127 |
+
" for meta in scaled_metadata:\n",
|
| 128 |
+
" f1, f2 = meta[\"position\"]\n",
|
| 129 |
+
" x1 = int(math.floor(f1 * scale_ratio))\n",
|
| 130 |
+
" x2 = int(math.ceil(f2 * scale_ratio))\n",
|
| 131 |
+
" masks[x1:x2] = 1\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" return complex_signal, masks\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"def process_metadata(metadata):\n",
|
| 138 |
+
" \"\"\"\n",
|
| 139 |
+
" Scale metadata to the dataset's frequency and bandwidth ranges.\n",
|
| 140 |
+
" \"\"\"\n",
|
| 141 |
+
" scaled_metadata = [\n",
|
| 142 |
+
" {\n",
|
| 143 |
+
" \"position\": (\n",
|
| 144 |
+
" math.floor((SAMPLE_RATE / 2 + i[\"fc\"] - i[\"bw\"] / 2) * SIGNAL_LENGTH / SAMPLE_RATE),\n",
|
| 145 |
+
" math.ceil((SAMPLE_RATE / 2 + i[\"fc\"] + i[\"bw\"] / 2) * SIGNAL_LENGTH / SAMPLE_RATE)\n",
|
| 146 |
+
" ),\n",
|
| 147 |
+
" \"snr\": 1, # Placeholder value\n",
|
| 148 |
+
" \"bw\": i[\"bw\"],\n",
|
| 149 |
+
" \"num\": len(metadata),\n",
|
| 150 |
+
" \"esn0\": 1, # Placeholder value\n",
|
| 151 |
+
" }\n",
|
| 152 |
+
" for i in metadata\n",
|
| 153 |
+
" ]\n",
|
| 154 |
+
" return scaled_metadata\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"# Dataset Splitting and Initialization\n",
|
| 157 |
+
"NEW_DATA_DIR = Path(\"/data/bigred/ofh/0\")\n",
|
| 158 |
+
"def get_real_signals(freq_directory):\n",
|
| 159 |
+
" return list(freq_directory.rglob(\"*.dat\"))\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"signal_dirs = get_real_signals(NEW_DATA_DIR)\n",
|
| 162 |
+
"total_signals = len(signal_dirs)\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"train_split = int(0.80 * total_signals)\n",
|
| 165 |
+
"validation_split = int(0.90 * total_signals)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"train, validation, test = (\n",
|
| 168 |
+
" signal_dirs[:train_split],\n",
|
| 169 |
+
" signal_dirs[train_split:validation_split],\n",
|
| 170 |
+
" signal_dirs[validation_split:]\n",
|
| 171 |
+
")\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"print(f\"Train set size: {len(train)}\")\n",
|
| 174 |
+
"print(f\"Validation set size: {len(validation)}\")\n",
|
| 175 |
+
"print(f\"Test set size: {len(test)}\")"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "code",
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"id": "f5305642",
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"outputs": [],
|
| 184 |
+
"source": [
|
| 185 |
+
"# Data Loaders\n",
|
| 186 |
+
"BATCH_SIZE = 64\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"train_dataset = WidebandSignalDataset(signal_ids=train)\n",
|
| 189 |
+
"validation_dataset = WidebandSignalDataset(signal_ids=validation)\n",
|
| 190 |
+
"test_dataset = WidebandSignalDataset(signal_ids=test)"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"cell_type": "code",
|
| 195 |
+
"execution_count": null,
|
| 196 |
+
"id": "54a4f325",
|
| 197 |
+
"metadata": {},
|
| 198 |
+
"outputs": [],
|
| 199 |
+
"source": [
|
| 200 |
+
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
| 201 |
+
"valid_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
|
| 202 |
+
"test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "markdown",
|
| 207 |
+
"id": "3893c583",
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"source": [
|
| 210 |
+
"### CV-ResNet-18"
|
| 211 |
+
]
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"cell_type": "code",
|
| 215 |
+
"execution_count": null,
|
| 216 |
+
"id": "bc2001c4",
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [],
|
| 219 |
+
"source": [
|
| 220 |
+
"import torch\n",
|
| 221 |
+
"import torch.nn as nn\n",
|
| 222 |
+
"import complexPyTorch.complexLayers as cplx\n",
|
| 223 |
+
"from typing import Optional, Callable, Type, Union, List\n",
|
| 224 |
+
"import torch.nn.functional as F\n",
|
| 225 |
+
"from torch import Tensor\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
|
| 228 |
+
" \"\"\"3x3 convolution with padding\"\"\"\n",
|
| 229 |
+
" return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
|
| 232 |
+
" \"\"\"1x1 convolution\"\"\"\n",
|
| 233 |
+
" return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"class BasicBlock(nn.Module):\n",
|
| 236 |
+
" expansion = 1\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" def __init__(\n",
|
| 239 |
+
" self,\n",
|
| 240 |
+
" inplanes: int,\n",
|
| 241 |
+
" planes: int,\n",
|
| 242 |
+
" stride: int = 1,\n",
|
| 243 |
+
" downsample: Optional[nn.Module] = None,\n",
|
| 244 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 245 |
+
" ) -> None:\n",
|
| 246 |
+
" super(BasicBlock, self).__init__()\n",
|
| 247 |
+
" self.conv1 = conv3x3(inplanes, planes, stride)\n",
|
| 248 |
+
" self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 249 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 250 |
+
" self.conv2 = conv3x3(planes, planes)\n",
|
| 251 |
+
" self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 252 |
+
" self.downsample = downsample\n",
|
| 253 |
+
" self.stride = stride\n",
|
| 254 |
+
"\n",
|
| 255 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 256 |
+
" identity = x\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" out = self.conv1(x)\n",
|
| 259 |
+
" out = self.bn1(out)\n",
|
| 260 |
+
" out = self.relu(out)\n",
|
| 261 |
+
"\n",
|
| 262 |
+
" out = self.conv2(out)\n",
|
| 263 |
+
" out = self.bn2(out)\n",
|
| 264 |
+
"\n",
|
| 265 |
+
" if self.downsample is not None:\n",
|
| 266 |
+
" identity = self.downsample(x)\n",
|
| 267 |
+
"\n",
|
| 268 |
+
" out += identity\n",
|
| 269 |
+
" out = self.relu(out)\n",
|
| 270 |
+
"\n",
|
| 271 |
+
" return out\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"class Bottleneck(nn.Module):\n",
|
| 274 |
+
" expansion = 4\n",
|
| 275 |
+
"\n",
|
| 276 |
+
" def __init__(\n",
|
| 277 |
+
" self,\n",
|
| 278 |
+
" inplanes: int,\n",
|
| 279 |
+
" planes: int,\n",
|
| 280 |
+
" stride: int = 1,\n",
|
| 281 |
+
" downsample: Optional[nn.Module] = None,\n",
|
| 282 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 283 |
+
" ) -> None:\n",
|
| 284 |
+
" super(Bottleneck, self).__init__()\n",
|
| 285 |
+
" self.conv1 = conv1x1(inplanes, planes)\n",
|
| 286 |
+
" self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 287 |
+
" self.conv2 = conv3x3(planes, planes, stride)\n",
|
| 288 |
+
" self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 289 |
+
" self.conv3 = conv1x1(planes, planes * self.expansion)\n",
|
| 290 |
+
" self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n",
|
| 291 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 292 |
+
" self.downsample = downsample\n",
|
| 293 |
+
" self.stride = stride\n",
|
| 294 |
+
"\n",
|
| 295 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 296 |
+
" identity = x\n",
|
| 297 |
+
"\n",
|
| 298 |
+
" out = self.conv1(x)\n",
|
| 299 |
+
" out = self.bn1(out)\n",
|
| 300 |
+
" out = self.relu(out)\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" out = self.conv2(out)\n",
|
| 303 |
+
" out = self.bn2(out)\n",
|
| 304 |
+
" out = self.relu(out)\n",
|
| 305 |
+
"\n",
|
| 306 |
+
" out = self.conv3(out)\n",
|
| 307 |
+
" out = self.bn3(out)\n",
|
| 308 |
+
"\n",
|
| 309 |
+
" if self.downsample is not None:\n",
|
| 310 |
+
" identity = self.downsample(x)\n",
|
| 311 |
+
"\n",
|
| 312 |
+
" out += identity\n",
|
| 313 |
+
" out = self.relu(out)\n",
|
| 314 |
+
"\n",
|
| 315 |
+
" return out\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"class ComplexResNet(nn.Module):\n",
|
| 318 |
+
" def __init__(\n",
|
| 319 |
+
" self,\n",
|
| 320 |
+
" block: Type[Union[BasicBlock, Bottleneck]],\n",
|
| 321 |
+
" layers: List[int],\n",
|
| 322 |
+
" num_classes: int = SIGNAL_LENGTH,\n",
|
| 323 |
+
" zero_init_residual: bool = False,\n",
|
| 324 |
+
" groups: int = 1,\n",
|
| 325 |
+
" width_per_group: int = 64,\n",
|
| 326 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 327 |
+
" ) -> None:\n",
|
| 328 |
+
" super(ComplexResNet, self).__init__()\n",
|
| 329 |
+
" if norm_layer is None:\n",
|
| 330 |
+
" norm_layer = cplx.ComplexBatchNorm2d\n",
|
| 331 |
+
" self._norm_layer = norm_layer\n",
|
| 332 |
+
"\n",
|
| 333 |
+
" self.inplanes = 64\n",
|
| 334 |
+
" self.dilation = 1\n",
|
| 335 |
+
"\n",
|
| 336 |
+
" self.groups = groups\n",
|
| 337 |
+
" self.base_width = width_per_group\n",
|
| 338 |
+
" self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n",
|
| 339 |
+
" self.bn1 = norm_layer(self.inplanes)\n",
|
| 340 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 341 |
+
" self.layer1 = self._make_layer(block, 64, layers[0])\n",
|
| 342 |
+
" self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
|
| 343 |
+
" self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
|
| 344 |
+
" self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
|
| 345 |
+
" self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
|
| 346 |
+
" self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n",
|
| 347 |
+
" self.sigmoid = cplx.ComplexSigmoid()\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n",
|
| 350 |
+
" norm_layer = self._norm_layer\n",
|
| 351 |
+
" downsample = None\n",
|
| 352 |
+
" if stride != 1 or self.inplanes != planes * block.expansion:\n",
|
| 353 |
+
" downsample = nn.Sequential(\n",
|
| 354 |
+
" conv1x1(self.inplanes, planes * block.expansion, stride),\n",
|
| 355 |
+
" norm_layer(planes * block.expansion),\n",
|
| 356 |
+
" )\n",
|
| 357 |
+
"\n",
|
| 358 |
+
" layers = []\n",
|
| 359 |
+
" layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n",
|
| 360 |
+
" self.inplanes = planes * block.expansion\n",
|
| 361 |
+
" for _ in range(1, blocks):\n",
|
| 362 |
+
" layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" return nn.Sequential(*layers)\n",
|
| 365 |
+
"\n",
|
| 366 |
+
" def _forward_impl(self, x: Tensor) -> Tensor:\n",
|
| 367 |
+
" x = self.conv1(x)\n",
|
| 368 |
+
" x = self.bn1(x)\n",
|
| 369 |
+
" x = self.relu(x)\n",
|
| 370 |
+
"\n",
|
| 371 |
+
" x = self.layer1(x)\n",
|
| 372 |
+
" x = self.layer2(x)\n",
|
| 373 |
+
" x = self.layer3(x)\n",
|
| 374 |
+
" x = self.layer4(x)\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" x = self.avgpool(x)\n",
|
| 377 |
+
" x = torch.flatten(x, 1)\n",
|
| 378 |
+
" x = self.fc(x)\n",
|
| 379 |
+
" x = self.sigmoid(x)\n",
|
| 380 |
+
" return x\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 383 |
+
" return self._forward_impl(x)\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"def ComplexResNet18():\n",
|
| 386 |
+
" return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n",
|
| 387 |
+
"\n",
|
| 388 |
+
"# Create the model instance\n",
|
| 389 |
+
"model = ComplexResNet18()\n",
|
| 390 |
+
"print(model)\n"
|
| 391 |
+
]
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
"cell_type": "markdown",
|
| 395 |
+
"id": "9a8e09e4",
|
| 396 |
+
"metadata": {},
|
| 397 |
+
"source": [
|
| 398 |
+
"### Early Stop"
|
| 399 |
+
]
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"cell_type": "code",
|
| 403 |
+
"execution_count": null,
|
| 404 |
+
"id": "24f79a24",
|
| 405 |
+
"metadata": {},
|
| 406 |
+
"outputs": [],
|
| 407 |
+
"source": [
|
| 408 |
+
"import os\n",
|
| 409 |
+
"\n",
|
| 410 |
+
"class EarlyStopping:\n",
|
| 411 |
+
" def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./path/to/model/save'):\n",
|
| 412 |
+
" self.patience = patience\n",
|
| 413 |
+
" self.verbose = verbose\n",
|
| 414 |
+
" self.delta = delta\n",
|
| 415 |
+
" self.counter = 0\n",
|
| 416 |
+
" self.best_score = None\n",
|
| 417 |
+
" self.early_stop = False\n",
|
| 418 |
+
" self.val_loss_min = float('inf')\n",
|
| 419 |
+
" self.best_model = None\n",
|
| 420 |
+
" self.save_path = save_path\n",
|
| 421 |
+
" os.makedirs(save_path, exist_ok=True)\n",
|
| 422 |
+
" \n",
|
| 423 |
+
" def __call__(self, val_loss, model):\n",
|
| 424 |
+
" score = -val_loss\n",
|
| 425 |
+
"\n",
|
| 426 |
+
" if self.best_score is None:\n",
|
| 427 |
+
" self.best_score = score\n",
|
| 428 |
+
" self.save_checkpoint(val_loss, model)\n",
|
| 429 |
+
" elif score < self.best_score + self.delta:\n",
|
| 430 |
+
" self.counter += 1\n",
|
| 431 |
+
" if self.verbose:\n",
|
| 432 |
+
" print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
|
| 433 |
+
" if self.counter >= self.patience:\n",
|
| 434 |
+
" self.early_stop = True\n",
|
| 435 |
+
" else:\n",
|
| 436 |
+
" self.best_score = score\n",
|
| 437 |
+
" self.save_checkpoint(val_loss, model)\n",
|
| 438 |
+
" self.counter = 0\n",
|
| 439 |
+
"\n",
|
| 440 |
+
" def save_checkpoint(self, val_loss, model):\n",
|
| 441 |
+
" if self.verbose:\n",
|
| 442 |
+
" print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n",
|
| 443 |
+
" self.val_loss_min = val_loss\n",
|
| 444 |
+
" self.best_model = model.state_dict()\n",
|
| 445 |
+
" save_path = os.path.join(self.save_path, 'best_model.pth')\n",
|
| 446 |
+
" torch.save(self.best_model, save_path)"
|
| 447 |
+
]
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"cell_type": "markdown",
|
| 451 |
+
"id": "6c3fda74",
|
| 452 |
+
"metadata": {},
|
| 453 |
+
"source": [
|
| 454 |
+
"### Focal loss and reshape"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"cell_type": "code",
|
| 459 |
+
"execution_count": null,
|
| 460 |
+
"id": "5fcf91db",
|
| 461 |
+
"metadata": {},
|
| 462 |
+
"outputs": [],
|
| 463 |
+
"source": [
|
| 464 |
+
"class ComplexFocalLoss(nn.Module):\n",
|
| 465 |
+
" def __init__(self, alpha=1, gamma=2, reduction='mean'):\n",
|
| 466 |
+
" super(ComplexFocalLoss, self).__init__()\n",
|
| 467 |
+
" self.alpha = alpha\n",
|
| 468 |
+
" self.gamma = gamma\n",
|
| 469 |
+
" self.reduction = reduction\n",
|
| 470 |
+
"\n",
|
| 471 |
+
" def forward(self, inputs, targets):\n",
|
| 472 |
+
" real_inputs = inputs.real\n",
|
| 473 |
+
" imag_inputs = inputs.imag\n",
|
| 474 |
+
" \n",
|
| 475 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
|
| 476 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
|
| 477 |
+
" \n",
|
| 478 |
+
" real_pt = torch.exp(-real_BCE_loss)\n",
|
| 479 |
+
" imag_pt = torch.exp(-imag_BCE_loss)\n",
|
| 480 |
+
" \n",
|
| 481 |
+
" real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
|
| 482 |
+
" imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
|
| 483 |
+
"\n",
|
| 484 |
+
" if self.reduction == 'mean':\n",
|
| 485 |
+
" return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
|
| 486 |
+
" elif self.reduction == 'sum':\n",
|
| 487 |
+
" return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
|
| 488 |
+
" else:\n",
|
| 489 |
+
" return real_F_loss + imag_F_loss\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"# Update the IoU calculation to handle complex values\n",
|
| 492 |
+
"def calculate_iou(pred, target, threshold=0.5):\n",
|
| 493 |
+
" real_pred = (pred.real > threshold).float()\n",
|
| 494 |
+
" imag_pred = (pred.imag > threshold).float()\n",
|
| 495 |
+
" \n",
|
| 496 |
+
" combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
|
| 497 |
+
" \n",
|
| 498 |
+
" intersection = (combined_pred * target).sum(dim=1)\n",
|
| 499 |
+
" union = (combined_pred + target).sum(dim=1) - intersection\n",
|
| 500 |
+
" iou = (intersection / union).mean().item()\n",
|
| 501 |
+
" return iou\n",
|
| 502 |
+
"def reshape_to_2d(data):\n",
|
| 503 |
+
" return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]"
|
| 504 |
+
]
|
| 505 |
+
},
|
| 506 |
+
{
|
| 507 |
+
"cell_type": "markdown",
|
| 508 |
+
"id": "c97635b0",
|
| 509 |
+
"metadata": {},
|
| 510 |
+
"source": [
|
| 511 |
+
"### BCE Loss"
|
| 512 |
+
]
|
| 513 |
+
},
|
| 514 |
+
{
|
| 515 |
+
"cell_type": "code",
|
| 516 |
+
"execution_count": null,
|
| 517 |
+
"id": "2e8b2892",
|
| 518 |
+
"metadata": {},
|
| 519 |
+
"outputs": [],
|
| 520 |
+
"source": [
|
| 521 |
+
"# CV BCE Loss Function Definition\n",
|
| 522 |
+
"class ComplexValuedBCELoss(nn.Module):\n",
|
| 523 |
+
" def __init__(self, reduction='mean'):\n",
|
| 524 |
+
" super(ComplexValuedBCELoss, self).__init__()\n",
|
| 525 |
+
" self.reduction = reduction\n",
|
| 526 |
+
"\n",
|
| 527 |
+
" def forward(self, inputs, targets):\n",
|
| 528 |
+
" real_inputs = inputs.real\n",
|
| 529 |
+
" imag_inputs = inputs.imag\n",
|
| 530 |
+
"\n",
|
| 531 |
+
" # Calculate binary cross-entropy for both real and imaginary parts\n",
|
| 532 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
|
| 533 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n",
|
| 534 |
+
" \n",
|
| 535 |
+
" # Combine the losses (you can adjust the weighting if necessary)\n",
|
| 536 |
+
" combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n",
|
| 537 |
+
" return combined_BCE_loss"
|
| 538 |
+
]
|
| 539 |
+
},
|
| 540 |
+
{
|
| 541 |
+
"cell_type": "markdown",
|
| 542 |
+
"id": "64f4063c",
|
| 543 |
+
"metadata": {},
|
| 544 |
+
"source": [
|
| 545 |
+
"### Training from scratch"
|
| 546 |
+
]
|
| 547 |
+
},
|
| 548 |
+
{
|
| 549 |
+
"cell_type": "code",
|
| 550 |
+
"execution_count": null,
|
| 551 |
+
"id": "66825110",
|
| 552 |
+
"metadata": {
|
| 553 |
+
"scrolled": false
|
| 554 |
+
},
|
| 555 |
+
"outputs": [],
|
| 556 |
+
"source": [
|
| 557 |
+
"import time\n",
|
| 558 |
+
"device=\"cuda\"\n",
|
| 559 |
+
"def validate_model(model, valid_loader, criterion):\n",
|
| 560 |
+
" model.eval()\n",
|
| 561 |
+
" running_loss = 0.0\n",
|
| 562 |
+
" iou_scores = []\n",
|
| 563 |
+
" total_correct = 0\n",
|
| 564 |
+
" total_samples = 0\n",
|
| 565 |
+
"\n",
|
| 566 |
+
" with torch.no_grad():\n",
|
| 567 |
+
" for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n",
|
| 568 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 569 |
+
" masks = masks.to(device)\n",
|
| 570 |
+
" outputs = model(inputs)\n",
|
| 571 |
+
" loss = criterion(outputs, masks)\n",
|
| 572 |
+
" running_loss += loss.item()\n",
|
| 573 |
+
"\n",
|
| 574 |
+
" # Calculate IoU\n",
|
| 575 |
+
" iou = calculate_iou(outputs, masks, threshold=0.5)\n",
|
| 576 |
+
" iou_scores.append(iou)\n",
|
| 577 |
+
" \n",
|
| 578 |
+
" # Calculate accuracy\n",
|
| 579 |
+
" preds = (outputs.real > 0.5).float()\n",
|
| 580 |
+
" correct = (preds == masks).float().sum()\n",
|
| 581 |
+
" total_correct += correct.item()\n",
|
| 582 |
+
" total_samples += masks.numel()\n",
|
| 583 |
+
"\n",
|
| 584 |
+
" val_loss = running_loss / len(valid_loader)\n",
|
| 585 |
+
" mean_iou = sum(iou_scores) / len(iou_scores)\n",
|
| 586 |
+
" accuracy = total_correct / total_samples * 100\n",
|
| 587 |
+
"\n",
|
| 588 |
+
" print(f'Validation Loss: {val_loss:.6f}')\n",
|
| 589 |
+
" print(f'Validation Accuracy: {accuracy:.2f}%')\n",
|
| 590 |
+
"\n",
|
| 591 |
+
" return val_loss, accuracy\n",
|
| 592 |
+
"\n",
|
| 593 |
+
"def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.0001], num_epochs=50, patience=5):\n",
|
| 594 |
+
" train_losses = []\n",
|
| 595 |
+
" val_losses = []\n",
|
| 596 |
+
" val_accuracies = []\n",
|
| 597 |
+
" epoch_durations = []\n",
|
| 598 |
+
" \n",
|
| 599 |
+
" current_lr = initial_lr\n",
|
| 600 |
+
" for lr in lr_steps:\n",
|
| 601 |
+
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
|
| 602 |
+
" early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n",
|
| 603 |
+
" print(\"Current learning rate: \", lr)\n",
|
| 604 |
+
" for epoch in range(num_epochs):\n",
|
| 605 |
+
" epoch_start_time = time.time()\n",
|
| 606 |
+
" \n",
|
| 607 |
+
" model.train()\n",
|
| 608 |
+
" running_loss = 0.0\n",
|
| 609 |
+
" for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n",
|
| 610 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 611 |
+
" masks = masks.to(device)\n",
|
| 612 |
+
" outputs = model(inputs)\n",
|
| 613 |
+
" loss = criterion(outputs, masks)\n",
|
| 614 |
+
"\n",
|
| 615 |
+
" optimizer.zero_grad()\n",
|
| 616 |
+
" loss.backward()\n",
|
| 617 |
+
" optimizer.step()\n",
|
| 618 |
+
"\n",
|
| 619 |
+
" running_loss += loss.item()\n",
|
| 620 |
+
"\n",
|
| 621 |
+
" epoch_loss = running_loss / len(train_loader)\n",
|
| 622 |
+
" train_losses.append(epoch_loss)\n",
|
| 623 |
+
" print(f\"Training Loss: {epoch_loss:.6f}\")\n",
|
| 624 |
+
" val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n",
|
| 625 |
+
" val_losses.append(val_loss)\n",
|
| 626 |
+
" val_accuracies.append(val_accuracy)\n",
|
| 627 |
+
" early_stopping(val_loss, model)\n",
|
| 628 |
+
"\n",
|
| 629 |
+
" if early_stopping.early_stop:\n",
|
| 630 |
+
" print(\"Early stopping triggered\")\n",
|
| 631 |
+
" break\n",
|
| 632 |
+
"\n",
|
| 633 |
+
" epoch_duration = time.time() - epoch_start_time\n",
|
| 634 |
+
" epoch_durations.append(epoch_duration)\n",
|
| 635 |
+
" if early_stopping.best_model is not None:\n",
|
| 636 |
+
" print(f\"Loading best model from lr {lr}\")\n",
|
| 637 |
+
" model.load_state_dict(early_stopping.best_model)\n",
|
| 638 |
+
" \n",
|
| 639 |
+
" print(\"Training completed.\")\n",
|
| 640 |
+
" print(\"Epoch durations:\", epoch_durations)\n",
|
| 641 |
+
" return model, train_losses, val_losses, val_accuracies, epoch_durations"
|
| 642 |
+
]
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
"cell_type": "code",
|
| 646 |
+
"execution_count": null,
|
| 647 |
+
"id": "621d28b3",
|
| 648 |
+
"metadata": {
|
| 649 |
+
"scrolled": false
|
| 650 |
+
},
|
| 651 |
+
"outputs": [],
|
| 652 |
+
"source": [
|
| 653 |
+
"# Initialize and train the ResNet-18 model\n",
|
| 654 |
+
"model = ComplexResNet18().to(device)\n",
|
| 655 |
+
"criterion = ComplexFocalLoss()\n",
|
| 656 |
+
"\n",
|
| 657 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n",
|
| 658 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 659 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 660 |
+
]
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"cell_type": "markdown",
|
| 664 |
+
"id": "3838c1bc",
|
| 665 |
+
"metadata": {},
|
| 666 |
+
"source": [
|
| 667 |
+
"### Transfer Learning Load pretrained model"
|
| 668 |
+
]
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"cell_type": "code",
|
| 672 |
+
"execution_count": null,
|
| 673 |
+
"id": "ac763e75",
|
| 674 |
+
"metadata": {},
|
| 675 |
+
"outputs": [],
|
| 676 |
+
"source": [
|
| 677 |
+
"# Path to the pre-trained model weights\n",
|
| 678 |
+
"pretrained_model_path = \"path/to/model/save.pth\" #Change this model to trained model\n",
|
| 679 |
+
"device=\"cuda\"\n",
|
| 680 |
+
"# Initialize the model architecture\n",
|
| 681 |
+
"model = ComplexResNet18().to(device)\n",
|
| 682 |
+
"\n",
|
| 683 |
+
"# Load the pre-trained weights\n",
|
| 684 |
+
"checkpoint = torch.load(pretrained_model_path)\n",
|
| 685 |
+
"model.load_state_dict(checkpoint, strict=False)\n",
|
| 686 |
+
"\n",
|
| 687 |
+
"# Set all layers as trainable (if needed)\n",
|
| 688 |
+
"for param in model.parameters():\n",
|
| 689 |
+
" param.requires_grad = True"
|
| 690 |
+
]
|
| 691 |
+
},
|
| 692 |
+
{
|
| 693 |
+
"cell_type": "code",
|
| 694 |
+
"execution_count": null,
|
| 695 |
+
"id": "1f877827",
|
| 696 |
+
"metadata": {
|
| 697 |
+
"scrolled": false
|
| 698 |
+
},
|
| 699 |
+
"outputs": [],
|
| 700 |
+
"source": [
|
| 701 |
+
"# Define a new criterion and optimizer for fine-tuning\n",
|
| 702 |
+
"# You may select between Focal Loss or BCE as your criterion\n",
|
| 703 |
+
"#criterion = ComplexValuedBCELoss() # or ComplexValuedBCELoss()\n",
|
| 704 |
+
"criterion = ComplexFocalLoss()\n",
|
| 705 |
+
"# Use a smaller learning rate for fine-tuning\n",
|
| 706 |
+
"optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n",
|
| 707 |
+
"\n",
|
| 708 |
+
"# Train the model (fine-tuning)\n",
|
| 709 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations= train_model(\n",
|
| 710 |
+
" model, train_loader, valid_loader, criterion,\n",
|
| 711 |
+
" initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
|
| 712 |
+
")\n",
|
| 713 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 714 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 715 |
+
]
|
| 716 |
+
},
|
| 717 |
+
{
|
| 718 |
+
"cell_type": "markdown",
|
| 719 |
+
"id": "f3784964",
|
| 720 |
+
"metadata": {},
|
| 721 |
+
"source": [
|
| 722 |
+
"### Plot Result and save the figures and json"
|
| 723 |
+
]
|
| 724 |
+
},
|
| 725 |
+
{
|
| 726 |
+
"cell_type": "code",
|
| 727 |
+
"execution_count": null,
|
| 728 |
+
"id": "67a52e13",
|
| 729 |
+
"metadata": {
|
| 730 |
+
"scrolled": false
|
| 731 |
+
},
|
| 732 |
+
"outputs": [],
|
| 733 |
+
"source": [
|
| 734 |
+
"import os\n",
|
| 735 |
+
"import json\n",
|
| 736 |
+
"import matplotlib.pyplot as plt\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"# Define save directory\n",
|
| 739 |
+
"save_dir = 'CMuSeNet_results/segmentation'\n",
|
| 740 |
+
"\n",
|
| 741 |
+
"# Create the directory if it doesn't exist\n",
|
| 742 |
+
"os.makedirs(save_dir, exist_ok=True)\n",
|
| 743 |
+
"\n",
|
| 744 |
+
"# Plot training loss\n",
|
| 745 |
+
"plt.figure()\n",
|
| 746 |
+
"plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', color='blue')\n",
|
| 747 |
+
"plt.title('Training Loss')\n",
|
| 748 |
+
"plt.xlabel('Epoch')\n",
|
| 749 |
+
"plt.ylabel('Loss')\n",
|
| 750 |
+
"plt.legend()\n",
|
| 751 |
+
"\n",
|
| 752 |
+
"# Save the training loss figure as PNG and SVG\n",
|
| 753 |
+
"plt.savefig(os.path.join(save_dir, 'training_loss.png'))\n",
|
| 754 |
+
"plt.savefig(os.path.join(save_dir, 'training_loss.svg'))\n",
|
| 755 |
+
"\n",
|
| 756 |
+
"# Show the training loss plot\n",
|
| 757 |
+
"plt.show()\n",
|
| 758 |
+
"\n",
|
| 759 |
+
"# Plot validation accuracy\n",
|
| 760 |
+
"plt.figure()\n",
|
| 761 |
+
"plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', color='green')\n",
|
| 762 |
+
"plt.title('Validation Accuracy')\n",
|
| 763 |
+
"plt.xlabel('Epoch')\n",
|
| 764 |
+
"plt.ylabel('Accuracy')\n",
|
| 765 |
+
"plt.legend()\n",
|
| 766 |
+
"\n",
|
| 767 |
+
"# Save the validation accuracy figure as PNG and SVG\n",
|
| 768 |
+
"plt.savefig(os.path.join(save_dir, 'validation_accuracy.png'))\n",
|
| 769 |
+
"plt.savefig(os.path.join(save_dir, 'validation_accuracy.svg'))\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"# Show the validation accuracy plot\n",
|
| 772 |
+
"plt.show()\n",
|
| 773 |
+
"\n",
|
| 774 |
+
"# Save the actual data to a JSON file\n",
|
| 775 |
+
"results = {\n",
|
| 776 |
+
" \"train_losses\": train_losses,\n",
|
| 777 |
+
" \"val_accuracies\": val_accuracies\n",
|
| 778 |
+
"}\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"# Save JSON file\n",
|
| 781 |
+
"with open(os.path.join(save_dir, 'training_validation_results.json'), 'w') as f:\n",
|
| 782 |
+
" json.dump(results, f)\n"
|
| 783 |
+
]
|
| 784 |
+
},
|
| 785 |
+
{
|
| 786 |
+
"cell_type": "markdown",
|
| 787 |
+
"id": "222069ae",
|
| 788 |
+
"metadata": {},
|
| 789 |
+
"source": [
|
| 790 |
+
"### BIG-RED Evaluation (Over entire dataset)"
|
| 791 |
+
]
|
| 792 |
+
},
|
| 793 |
+
{
|
| 794 |
+
"cell_type": "code",
|
| 795 |
+
"execution_count": null,
|
| 796 |
+
"id": "6b178984",
|
| 797 |
+
"metadata": {},
|
| 798 |
+
"outputs": [],
|
| 799 |
+
"source": [
|
| 800 |
+
"import torch\n",
|
| 801 |
+
"from torch.utils.data import DataLoader\n",
|
| 802 |
+
"from tqdm import tqdm\n",
|
| 803 |
+
"# Create a DataLoader for the entire dataset\n",
|
| 804 |
+
"BATCH_SIZE = 64 # Adjust based on available memory\n",
|
| 805 |
+
"entire_dataset = WidebandSignalDataset(signal_ids=signal_dirs) # Use all signals\n",
|
| 806 |
+
"entire_loader = DataLoader(entire_dataset, batch_size=BATCH_SIZE, shuffle=False)"
|
| 807 |
+
]
|
| 808 |
+
},
|
| 809 |
+
{
|
| 810 |
+
"cell_type": "code",
|
| 811 |
+
"execution_count": null,
|
| 812 |
+
"id": "2e6be59a",
|
| 813 |
+
"metadata": {},
|
| 814 |
+
"outputs": [],
|
| 815 |
+
"source": [
|
| 816 |
+
"# Path to the pre-trained model weights\n",
|
| 817 |
+
"pretrained_model_path = \"path/to/model/pretrained\" \n",
|
| 818 |
+
"device = \"cuda\" \n",
|
| 819 |
+
"\n",
|
| 820 |
+
"# Initialize the model architecture\n",
|
| 821 |
+
"model = ComplexResNet18().to(device)\n",
|
| 822 |
+
"\n",
|
| 823 |
+
"# Load the pre-trained weights\n",
|
| 824 |
+
"checkpoint = torch.load(pretrained_model_path, map_location=device)\n",
|
| 825 |
+
"model.load_state_dict(checkpoint, strict=False)\n",
|
| 826 |
+
"model.eval()\n",
|
| 827 |
+
"\n",
|
| 828 |
+
"# Function to evaluate accuracy\n",
|
| 829 |
+
"def evaluate_accuracy(model, data_loader):\n",
|
| 830 |
+
" total_correct = 0\n",
|
| 831 |
+
" total_samples = 0\n",
|
| 832 |
+
"\n",
|
| 833 |
+
" with torch.no_grad():\n",
|
| 834 |
+
" for inputs, masks in tqdm(data_loader, desc=\"Evaluating on Entire Dataset\"):\n",
|
| 835 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 836 |
+
" masks = masks.to(device)\n",
|
| 837 |
+
"\n",
|
| 838 |
+
" outputs = model(inputs)\n",
|
| 839 |
+
" preds = (outputs.real > 0.5).float()\n",
|
| 840 |
+
"\n",
|
| 841 |
+
" correct = (preds == masks).float().sum()\n",
|
| 842 |
+
" total_correct += correct.item()\n",
|
| 843 |
+
" total_samples += masks.numel()\n",
|
| 844 |
+
"\n",
|
| 845 |
+
" accuracy = total_correct / total_samples * 100\n",
|
| 846 |
+
" print(f\"Overall Accuracy on Entire Dataset: {accuracy:.2f}%\")\n",
|
| 847 |
+
" return accuracy\n",
|
| 848 |
+
"\n",
|
| 849 |
+
"# Run the evaluation\n",
|
| 850 |
+
"overall_accuracy = evaluate_accuracy(model, entire_loader)"
|
| 851 |
+
]
|
| 852 |
+
},
|
| 853 |
+
{
|
| 854 |
+
"cell_type": "markdown",
|
| 855 |
+
"id": "2a5a21b4",
|
| 856 |
+
"metadata": {},
|
| 857 |
+
"source": [
|
| 858 |
+
"### Function definitions"
|
| 859 |
+
]
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"cell_type": "code",
|
| 863 |
+
"execution_count": null,
|
| 864 |
+
"id": "b223d9b5",
|
| 865 |
+
"metadata": {},
|
| 866 |
+
"outputs": [],
|
| 867 |
+
"source": [
|
| 868 |
+
"import torch\n",
|
| 869 |
+
"from tqdm import tqdm\n",
|
| 870 |
+
"import numpy as np\n",
|
| 871 |
+
"from collections import defaultdict\n",
|
| 872 |
+
"import torch.nn.functional as F\n",
|
| 873 |
+
"from scipy.optimize import linear_sum_assignment\n",
|
| 874 |
+
"from torch.utils.data import ConcatDataset"
|
| 875 |
+
]
|
| 876 |
+
},
|
| 877 |
+
{
|
| 878 |
+
"cell_type": "code",
|
| 879 |
+
"execution_count": null,
|
| 880 |
+
"id": "f54736ea",
|
| 881 |
+
"metadata": {},
|
| 882 |
+
"outputs": [],
|
| 883 |
+
"source": [
|
| 884 |
+
"# Load the pre-trained model for evaluation\n",
|
| 885 |
+
"device = \"cuda\"\n",
|
| 886 |
+
"model_path = \"path/to/model/save.pth\"\n",
|
| 887 |
+
"model = resnet18_1D().to(device)\n",
|
| 888 |
+
"model.load_state_dict(torch.load(model_path, map_location=device))\n",
|
| 889 |
+
"model.eval()\n"
|
| 890 |
+
]
|
| 891 |
+
},
|
| 892 |
+
{
|
| 893 |
+
"cell_type": "code",
|
| 894 |
+
"execution_count": null,
|
| 895 |
+
"id": "dd5e7fee",
|
| 896 |
+
"metadata": {},
|
| 897 |
+
"outputs": [],
|
| 898 |
+
"source": [
|
| 899 |
+
"full_dataset = ConcatDataset([\n",
|
| 900 |
+
" WidebandSignalDataset(signal_ids=train, return_snrs=True),\n",
|
| 901 |
+
" WidebandSignalDataset(signal_ids=validation, return_snrs=True),\n",
|
| 902 |
+
" WidebandSignalDataset(signal_ids=test, return_snrs=True)\n",
|
| 903 |
+
"])"
|
| 904 |
+
]
|
| 905 |
+
},
|
| 906 |
+
{
|
| 907 |
+
"cell_type": "code",
|
| 908 |
+
"execution_count": null,
|
| 909 |
+
"id": "173f9a8c",
|
| 910 |
+
"metadata": {},
|
| 911 |
+
"outputs": [],
|
| 912 |
+
"source": [
|
| 913 |
+
"full_loader = DataLoader(full_dataset, batch_size=64, shuffle=False)"
|
| 914 |
+
]
|
| 915 |
+
},
|
| 916 |
+
{
|
| 917 |
+
"cell_type": "code",
|
| 918 |
+
"execution_count": null,
|
| 919 |
+
"id": "95f711d0",
|
| 920 |
+
"metadata": {},
|
| 921 |
+
"outputs": [],
|
| 922 |
+
"source": [
|
| 923 |
+
"def expand_true(array, distance=1):\n",
|
| 924 |
+
" # Create kernel of appropriate size\n",
|
| 925 |
+
" kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n",
|
| 926 |
+
" array = array.unsqueeze(1).float() # Add channel dimension\n",
|
| 927 |
+
" result = F.conv1d(array, kernel, padding=distance)\n",
|
| 928 |
+
" result = result.squeeze(1) # Remove the extra dimension\n",
|
| 929 |
+
" return result > 0\n",
|
| 930 |
+
"\n",
|
| 931 |
+
"def get_true_groups(tensor, device):\n",
|
| 932 |
+
" assert tensor.dim() == 2, 'This function handles 2D tensor only'\n",
|
| 933 |
+
" all_groups = []\n",
|
| 934 |
+
" for i in range(tensor.size(0)):\n",
|
| 935 |
+
" item = tensor[i]\n",
|
| 936 |
+
" item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n",
|
| 937 |
+
" diffs = item.float().diff()\n",
|
| 938 |
+
" starts = (diffs == 1).nonzero(as_tuple=True)[0]\n",
|
| 939 |
+
" ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n",
|
| 940 |
+
" groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n",
|
| 941 |
+
" all_groups.append(groups)\n",
|
| 942 |
+
" return all_groups\n",
|
| 943 |
+
"\n",
|
| 944 |
+
"def calculate_iou(box1, box2):\n",
|
| 945 |
+
" intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n",
|
| 946 |
+
" union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n",
|
| 947 |
+
" return intersection / union if union != 0 else 0\n",
|
| 948 |
+
"\n",
|
| 949 |
+
"def match_targets(targets, preds):\n",
|
| 950 |
+
" ious = []\n",
|
| 951 |
+
" for target in targets:\n",
|
| 952 |
+
" iou_targets = []\n",
|
| 953 |
+
" for pred in preds:\n",
|
| 954 |
+
" iou_targets.append(calculate_iou(target, pred))\n",
|
| 955 |
+
" ious.append(iou_targets)\n",
|
| 956 |
+
" cost_matrix = np.array(ious)\n",
|
| 957 |
+
" row_ind, col_ind = linear_sum_assignment(-cost_matrix)\n",
|
| 958 |
+
" return row_ind, col_ind\n",
|
| 959 |
+
"\n",
|
| 960 |
+
"def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n",
|
| 961 |
+
" ious = [0 for _ in target_boxes]\n",
|
| 962 |
+
" matching_dict = dict(zip(*matching))\n",
|
| 963 |
+
" for target_index, target_box in enumerate(target_boxes):\n",
|
| 964 |
+
" if target_index in matching_dict:\n",
|
| 965 |
+
" pred_index = matching_dict[target_index]\n",
|
| 966 |
+
" if pred_index < len(prediction_boxes):\n",
|
| 967 |
+
" box1 = target_box\n",
|
| 968 |
+
" box2 = prediction_boxes[pred_index]\n",
|
| 969 |
+
" ious[target_index] = calculate_iou(box1, box2)\n",
|
| 970 |
+
" return ious\n"
|
| 971 |
+
]
|
| 972 |
+
},
|
| 973 |
+
{
|
| 974 |
+
"cell_type": "code",
|
| 975 |
+
"execution_count": null,
|
| 976 |
+
"id": "40ec3d9f",
|
| 977 |
+
"metadata": {},
|
| 978 |
+
"outputs": [],
|
| 979 |
+
"source": [
|
| 980 |
+
"def evaluate(predictor, data_loader, device=\"cuda\"):\n",
|
| 981 |
+
" iou_thresholds = [0.5, 0.7, 0.9]\n",
|
| 982 |
+
" snr_metrics = defaultdict(lambda: {\n",
|
| 983 |
+
" \"iou_sum\": 0.0,\n",
|
| 984 |
+
" \"iou_count\": 0,\n",
|
| 985 |
+
" \"recall_counts\": defaultdict(int),\n",
|
| 986 |
+
" \"total_samples\": defaultdict(int),\n",
|
| 987 |
+
" \"correct_pixels\": 0,\n",
|
| 988 |
+
" \"total_pixels\": 0\n",
|
| 989 |
+
" })\n",
|
| 990 |
+
" total_iou_sum, total_iou_count = 0.0, 0\n",
|
| 991 |
+
" total_correct_pixels, total_total_pixels = 0, 0\n",
|
| 992 |
+
" total_recall_counts = defaultdict(int)\n",
|
| 993 |
+
" total_samples = defaultdict(int)\n",
|
| 994 |
+
"\n",
|
| 995 |
+
" for batch in tqdm(data_loader, desc=\"Evaluating\"):\n",
|
| 996 |
+
" if len(batch) == 3:\n",
|
| 997 |
+
" inputs, masks, snrs_in_batch = batch\n",
|
| 998 |
+
" else:\n",
|
| 999 |
+
" inputs, masks = batch\n",
|
| 1000 |
+
" snrs_in_batch = [0] * len(inputs) # Default SNR if not provided\n",
|
| 1001 |
+
"\n",
|
| 1002 |
+
" inputs = inputs.to(device)\n",
|
| 1003 |
+
" masks = masks.to(device)\n",
|
| 1004 |
+
" outputs = predictor(inputs)\n",
|
| 1005 |
+
"\n",
|
| 1006 |
+
" for i in range(len(inputs)):\n",
|
| 1007 |
+
" mask = masks[i]\n",
|
| 1008 |
+
" output = outputs[i]\n",
|
| 1009 |
+
"\n",
|
| 1010 |
+
" # Resize output to match mask shape if necessary\n",
|
| 1011 |
+
" if output.numel() != mask.numel():\n",
|
| 1012 |
+
" output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n",
|
| 1013 |
+
"\n",
|
| 1014 |
+
" thresholded_output = (output >= 0.5).float()\n",
|
| 1015 |
+
"\n",
|
| 1016 |
+
" correct_pixels = (thresholded_output == mask).sum().item()\n",
|
| 1017 |
+
" total_pixels = mask.numel()\n",
|
| 1018 |
+
" total_correct_pixels += correct_pixels\n",
|
| 1019 |
+
" total_total_pixels += total_pixels\n",
|
| 1020 |
+
"\n",
|
| 1021 |
+
" # Get SNR value and round it to the nearest integer\n",
|
| 1022 |
+
" snr = snrs_in_batch[i]\n",
|
| 1023 |
+
" if isinstance(snr, torch.Tensor):\n",
|
| 1024 |
+
" snr = snr.item()\n",
|
| 1025 |
+
" snr = int(round(snr)) # Round SNR to the nearest integer\n",
|
| 1026 |
+
"\n",
|
| 1027 |
+
" snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n",
|
| 1028 |
+
" snr_metrics[snr][\"total_pixels\"] += total_pixels\n",
|
| 1029 |
+
"\n",
|
| 1030 |
+
" target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n",
|
| 1031 |
+
" pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n",
|
| 1032 |
+
" if not target_boxes or not pred_boxes:\n",
|
| 1033 |
+
" continue\n",
|
| 1034 |
+
" matching = match_targets(target_boxes, pred_boxes)\n",
|
| 1035 |
+
" matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n",
|
| 1036 |
+
"\n",
|
| 1037 |
+
" snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n",
|
| 1038 |
+
" snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n",
|
| 1039 |
+
" total_iou_sum += sum(matched_ious)\n",
|
| 1040 |
+
" total_iou_count += len(matched_ious)\n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
" for th in iou_thresholds:\n",
|
| 1043 |
+
" true_positives = sum(1 for iou in matched_ious if iou >= th)\n",
|
| 1044 |
+
" snr_metrics[snr][\"recall_counts\"][th] += true_positives\n",
|
| 1045 |
+
" snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n",
|
| 1046 |
+
" total_recall_counts[th] += true_positives\n",
|
| 1047 |
+
" total_samples[th] += len(target_boxes)\n",
|
| 1048 |
+
"\n",
|
| 1049 |
+
" # Calculate overall metrics\n",
|
| 1050 |
+
" overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n",
|
| 1051 |
+
" overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n",
|
| 1052 |
+
" overall_recall = {\n",
|
| 1053 |
+
" th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0\n",
|
| 1054 |
+
" for th in iou_thresholds\n",
|
| 1055 |
+
" }\n",
|
| 1056 |
+
"\n",
|
| 1057 |
+
" # Print overall results\n",
|
| 1058 |
+
" print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n",
|
| 1059 |
+
" print(f\"Overall IoU Score: {overall_iou:.4f}\")\n",
|
| 1060 |
+
" for th in iou_thresholds:\n",
|
| 1061 |
+
" print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n",
|
| 1062 |
+
"\n",
|
| 1063 |
+
" # Print per-SNR results\n",
|
| 1064 |
+
" for snr in sorted(snr_metrics.keys()):\n",
|
| 1065 |
+
" metrics = snr_metrics[snr]\n",
|
| 1066 |
+
" snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
|
| 1067 |
+
" snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
|
| 1068 |
+
" print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n",
|
| 1069 |
+
" print(f\" IoU: {snr_iou:.4f}\")\n",
|
| 1070 |
+
" for th in iou_thresholds:\n",
|
| 1071 |
+
" recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n",
|
| 1072 |
+
" print(f\" Recall at threshold {th}: {recall:.4f}\")\n",
|
| 1073 |
+
"\n",
|
| 1074 |
+
" return snr_metrics\n",
|
| 1075 |
+
"\n",
|
| 1076 |
+
"\n",
|
| 1077 |
+
"def model_predictor(signals):\n",
|
| 1078 |
+
" # Use the already loaded model and apply thresholding\n",
|
| 1079 |
+
" return expand_true(model(signals) > 0.5)\n"
|
| 1080 |
+
]
|
| 1081 |
+
},
|
| 1082 |
+
{
|
| 1083 |
+
"cell_type": "code",
|
| 1084 |
+
"execution_count": null,
|
| 1085 |
+
"id": "c7d3aed7",
|
| 1086 |
+
"metadata": {
|
| 1087 |
+
"scrolled": false
|
| 1088 |
+
},
|
| 1089 |
+
"outputs": [],
|
| 1090 |
+
"source": [
|
| 1091 |
+
"# Run evaluation on the full dataset\n",
|
| 1092 |
+
"snr_metrics = evaluate(model_predictor, full_loader, device=device)"
|
| 1093 |
+
]
|
| 1094 |
+
},
|
| 1095 |
+
{
|
| 1096 |
+
"cell_type": "markdown",
|
| 1097 |
+
"id": "2fd3ba0e",
|
| 1098 |
+
"metadata": {},
|
| 1099 |
+
"source": [
|
| 1100 |
+
"### Save and plot"
|
| 1101 |
+
]
|
| 1102 |
+
},
|
| 1103 |
+
{
|
| 1104 |
+
"cell_type": "code",
|
| 1105 |
+
"execution_count": null,
|
| 1106 |
+
"id": "aef69113",
|
| 1107 |
+
"metadata": {},
|
| 1108 |
+
"outputs": [],
|
| 1109 |
+
"source": [
|
| 1110 |
+
"import os\n",
|
| 1111 |
+
"import json\n",
|
| 1112 |
+
"import matplotlib.pyplot as plt\n",
|
| 1113 |
+
"\n",
|
| 1114 |
+
"def save_results_and_plot(snr_metrics, save_path):\n",
|
| 1115 |
+
" \"\"\"\n",
|
| 1116 |
+
" Saves evaluation results to a JSON file and generates plots for Accuracy, IoU, and Recall vs. SNR.\n",
|
| 1117 |
+
" Sets x-axis limits to range from -9 dB to 12 dB to eliminate blank space on the right.\n",
|
| 1118 |
+
"\n",
|
| 1119 |
+
" Args:\n",
|
| 1120 |
+
" snr_metrics (dict): The evaluation results obtained from the evaluate function.\n",
|
| 1121 |
+
" save_path (str): The directory path where results and plots will be saved.\n",
|
| 1122 |
+
"\n",
|
| 1123 |
+
" Outputs:\n",
|
| 1124 |
+
" - evaluation_results.json\n",
|
| 1125 |
+
" - accuracy_vs_snr.png and .svg\n",
|
| 1126 |
+
" - iou_vs_snr.png and .svg\n",
|
| 1127 |
+
" - recall_vs_snr.png and .svg\n",
|
| 1128 |
+
" \"\"\"\n",
|
| 1129 |
+
" # Ensure the directory exists\n",
|
| 1130 |
+
" os.makedirs(save_path, exist_ok=True)\n",
|
| 1131 |
+
" \n",
|
| 1132 |
+
" # Extract data from snr_metrics\n",
|
| 1133 |
+
" snr_list = sorted(snr_metrics.keys())\n",
|
| 1134 |
+
" accuracy_list = []\n",
|
| 1135 |
+
" iou_list = []\n",
|
| 1136 |
+
" recall_05 = []\n",
|
| 1137 |
+
" recall_07 = []\n",
|
| 1138 |
+
" recall_09 = []\n",
|
| 1139 |
+
" \n",
|
| 1140 |
+
" # Prepare data for JSON serialization\n",
|
| 1141 |
+
" json_data = {}\n",
|
| 1142 |
+
" \n",
|
| 1143 |
+
" for snr in snr_list:\n",
|
| 1144 |
+
" metrics = snr_metrics[snr]\n",
|
| 1145 |
+
" snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
|
| 1146 |
+
" snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
|
| 1147 |
+
" recall_at_05 = metrics[\"recall_counts\"][0.5] / metrics[\"total_samples\"][0.5] if metrics[\"total_samples\"][0.5] > 0 else 0\n",
|
| 1148 |
+
" recall_at_07 = metrics[\"recall_counts\"][0.7] / metrics[\"total_samples\"][0.7] if metrics[\"total_samples\"][0.7] > 0 else 0\n",
|
| 1149 |
+
" recall_at_09 = metrics[\"recall_counts\"][0.9] / metrics[\"total_samples\"][0.9] if metrics[\"total_samples\"][0.9] > 0 else 0\n",
|
| 1150 |
+
"\n",
|
| 1151 |
+
" # Append to lists for plotting\n",
|
| 1152 |
+
" accuracy_list.append(snr_accuracy)\n",
|
| 1153 |
+
" iou_list.append(snr_iou)\n",
|
| 1154 |
+
" recall_05.append(recall_at_05)\n",
|
| 1155 |
+
" recall_07.append(recall_at_07)\n",
|
| 1156 |
+
" recall_09.append(recall_at_09)\n",
|
| 1157 |
+
"\n",
|
| 1158 |
+
" # Prepare data for JSON\n",
|
| 1159 |
+
" json_data[snr] = {\n",
|
| 1160 |
+
" \"accuracy\": snr_accuracy,\n",
|
| 1161 |
+
" \"iou\": snr_iou,\n",
|
| 1162 |
+
" \"recall\": {\n",
|
| 1163 |
+
" \"0.5\": recall_at_05,\n",
|
| 1164 |
+
" \"0.7\": recall_at_07,\n",
|
| 1165 |
+
" \"0.9\": recall_at_09,\n",
|
| 1166 |
+
" }\n",
|
| 1167 |
+
" }\n",
|
| 1168 |
+
" \n",
|
| 1169 |
+
" # Save json_data to JSON file\n",
|
| 1170 |
+
" json_file_path = os.path.join(save_path, 'evaluation_results.json')\n",
|
| 1171 |
+
" with open(json_file_path, 'w') as json_file:\n",
|
| 1172 |
+
" json.dump(json_data, json_file, indent=4)\n",
|
| 1173 |
+
" \n",
|
| 1174 |
+
" # Plot Accuracy vs. SNR\n",
|
| 1175 |
+
" plt.figure(figsize=(10, 6))\n",
|
| 1176 |
+
" plt.plot(snr_list, accuracy_list, marker='o', label='Accuracy')\n",
|
| 1177 |
+
" plt.title('Accuracy vs. SNR')\n",
|
| 1178 |
+
" plt.xlabel('SNR (dB)')\n",
|
| 1179 |
+
" plt.ylabel('Accuracy (%)')\n",
|
| 1180 |
+
" plt.grid(True)\n",
|
| 1181 |
+
" plt.legend()\n",
|
| 1182 |
+
" \n",
|
| 1183 |
+
" # Set x-axis limits\n",
|
| 1184 |
+
" plt.xlim(-9, 12)\n",
|
| 1185 |
+
" \n",
|
| 1186 |
+
" # Save the plot\n",
|
| 1187 |
+
" accuracy_png_path = os.path.join(save_path, 'accuracy_vs_snr.png')\n",
|
| 1188 |
+
" accuracy_svg_path = os.path.join(save_path, 'accuracy_vs_snr.svg')\n",
|
| 1189 |
+
" plt.savefig(accuracy_png_path, format='png', bbox_inches='tight')\n",
|
| 1190 |
+
" plt.savefig(accuracy_svg_path, format='svg', bbox_inches='tight')\n",
|
| 1191 |
+
" \n",
|
| 1192 |
+
" plt.show()\n",
|
| 1193 |
+
" plt.close()\n",
|
| 1194 |
+
" \n",
|
| 1195 |
+
" # Plot IoU vs. SNR\n",
|
| 1196 |
+
" plt.figure(figsize=(10, 6))\n",
|
| 1197 |
+
" plt.plot(snr_list, iou_list, marker='o', color='orange', label='IoU')\n",
|
| 1198 |
+
" plt.title('IoU vs. SNR')\n",
|
| 1199 |
+
" plt.xlabel('SNR (dB)')\n",
|
| 1200 |
+
" plt.ylabel('IoU')\n",
|
| 1201 |
+
" plt.grid(True)\n",
|
| 1202 |
+
" plt.legend()\n",
|
| 1203 |
+
" \n",
|
| 1204 |
+
" # Set x-axis limits\n",
|
| 1205 |
+
" plt.xlim(-9, 12)\n",
|
| 1206 |
+
" \n",
|
| 1207 |
+
" # Save the plot\n",
|
| 1208 |
+
" iou_png_path = os.path.join(save_path, 'iou_vs_snr.png')\n",
|
| 1209 |
+
" iou_svg_path = os.path.join(save_path, 'iou_vs_snr.svg')\n",
|
| 1210 |
+
" plt.savefig(iou_png_path, format='png', bbox_inches='tight')\n",
|
| 1211 |
+
" plt.savefig(iou_svg_path, format='svg', bbox_inches='tight')\n",
|
| 1212 |
+
" \n",
|
| 1213 |
+
" plt.show()\n",
|
| 1214 |
+
" plt.close()\n",
|
| 1215 |
+
" \n",
|
| 1216 |
+
" # Plot Recall at Different IoU Thresholds vs. SNR\n",
|
| 1217 |
+
" plt.figure(figsize=(10, 6))\n",
|
| 1218 |
+
" plt.plot(snr_list, recall_05, marker='o', label='Recall @ IoU 0.5')\n",
|
| 1219 |
+
" plt.plot(snr_list, recall_07, marker='s', label='Recall @ IoU 0.7')\n",
|
| 1220 |
+
" plt.plot(snr_list, recall_09, marker='^', label='Recall @ IoU 0.9')\n",
|
| 1221 |
+
" plt.title('Recall at Different IoU Thresholds vs. SNR')\n",
|
| 1222 |
+
" plt.xlabel('SNR (dB)')\n",
|
| 1223 |
+
" plt.ylabel('Recall')\n",
|
| 1224 |
+
" plt.grid(True)\n",
|
| 1225 |
+
" plt.legend()\n",
|
| 1226 |
+
" \n",
|
| 1227 |
+
" # Set x-axis limits\n",
|
| 1228 |
+
" plt.xlim(-9, 12)\n",
|
| 1229 |
+
" \n",
|
| 1230 |
+
" # Save the plot\n",
|
| 1231 |
+
" recall_png_path = os.path.join(save_path, 'recall_vs_snr.png')\n",
|
| 1232 |
+
" recall_svg_path = os.path.join(save_path, 'recall_vs_snr.svg')\n",
|
| 1233 |
+
" plt.savefig(recall_png_path, format='png', bbox_inches='tight')\n",
|
| 1234 |
+
" plt.savefig(recall_svg_path, format='svg', bbox_inches='tight')\n",
|
| 1235 |
+
" \n",
|
| 1236 |
+
" plt.show()\n",
|
| 1237 |
+
" plt.close()\n"
|
| 1238 |
+
]
|
| 1239 |
+
},
|
| 1240 |
+
{
|
| 1241 |
+
"cell_type": "code",
|
| 1242 |
+
"execution_count": null,
|
| 1243 |
+
"id": "c9595d5e",
|
| 1244 |
+
"metadata": {},
|
| 1245 |
+
"outputs": [],
|
| 1246 |
+
"source": [
|
| 1247 |
+
"# Assuming snr_metrics is the output from the evaluate function\n",
|
| 1248 |
+
"# Set the save path\n",
|
| 1249 |
+
"save_path = 'CMuSeNet_BIGRED_results'\n",
|
| 1250 |
+
"\n",
|
| 1251 |
+
"# Call the function\n",
|
| 1252 |
+
"save_results_and_plot(snr_metrics, save_path)\n"
|
| 1253 |
+
]
|
| 1254 |
+
}
|
| 1255 |
+
],
|
| 1256 |
+
"metadata": {
|
| 1257 |
+
"kernelspec": {
|
| 1258 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1259 |
+
"language": "python",
|
| 1260 |
+
"name": "python3"
|
| 1261 |
+
},
|
| 1262 |
+
"language_info": {
|
| 1263 |
+
"codemirror_mode": {
|
| 1264 |
+
"name": "ipython",
|
| 1265 |
+
"version": 3
|
| 1266 |
+
},
|
| 1267 |
+
"file_extension": ".py",
|
| 1268 |
+
"mimetype": "text/x-python",
|
| 1269 |
+
"name": "python",
|
| 1270 |
+
"nbconvert_exporter": "python",
|
| 1271 |
+
"pygments_lexer": "ipython3",
|
| 1272 |
+
"version": "3.10.9"
|
| 1273 |
+
}
|
| 1274 |
+
},
|
| 1275 |
+
"nbformat": 4,
|
| 1276 |
+
"nbformat_minor": 5
|
| 1277 |
+
}
|
CMuSeNet_Indoor_OTA.ipynb
ADDED
|
@@ -0,0 +1,1658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "b5007b71",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"### Initialization"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "3e6b1226",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"### Initialization block\n",
|
| 19 |
+
"from pathlib import Path\n",
|
| 20 |
+
"import numpy as np\n",
|
| 21 |
+
"import json\n",
|
| 22 |
+
"import torch\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"from tqdm import tqdm\n",
|
| 25 |
+
"import math\n",
|
| 26 |
+
"from torch.utils.data import DataLoader, TensorDataset\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"STFT_LENGTH = 16 * 1024\n",
|
| 29 |
+
"DATA_DIR = Path(\"/data/OTA_reduced/\")\n",
|
| 30 |
+
"SAMPLE_RATE = 20e6\n",
|
| 31 |
+
"MODULATIONS = [\"QPSK\", \"BPSK\", \"2-FSK\"]\n",
|
| 32 |
+
"MODULATION_LABELS = {j: i for i, j in enumerate(MODULATIONS)}\n",
|
| 33 |
+
"NUMBER_OF_MODULATIONS = len(MODULATIONS)\n",
|
| 34 |
+
"MASK_SIZE = int(STFT_LENGTH)\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"from matplotlib.mlab import psd as apply_psd\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"def calc_sig_power(signal, meta, noise_power=-132.065):\n",
|
| 39 |
+
" \n",
|
| 40 |
+
" noise_floor_linear = 10 ** (noise_power / 10)\n",
|
| 41 |
+
" (psd, frequencies) = apply_psd(signal, Fs=SAMPLE_RATE, NFFT=1024)\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"\n",
|
| 44 |
+
" signal_position = []\n",
|
| 45 |
+
"\n",
|
| 46 |
+
" body = meta[\"body\"]\n",
|
| 47 |
+
" device = meta[\"client_id\"]\n",
|
| 48 |
+
" bandwidth, frequency_offset = body[\"bandwidth\"] + 20e3, body[\"frequency_offset\"]\n",
|
| 49 |
+
"\n",
|
| 50 |
+
" \n",
|
| 51 |
+
" below_freq = frequency_offset-bandwidth/2\n",
|
| 52 |
+
" upper_freq = frequency_offset+bandwidth/2\n",
|
| 53 |
+
" sum_power_dbs = 0\n",
|
| 54 |
+
" freq_count = 0\n",
|
| 55 |
+
" \n",
|
| 56 |
+
" for idx, (power, freq) in enumerate(zip(psd, frequencies)):\n",
|
| 57 |
+
" if below_freq <= freq <= upper_freq:\n",
|
| 58 |
+
" freq_count+=1\n",
|
| 59 |
+
" sum_power_dbs+=(power)\n",
|
| 60 |
+
" return sum_power_dbs\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"# noise_power is measured from noise signal collection\n",
|
| 63 |
+
"def calc_snr(signal_power, noise_power=-132.065):\n",
|
| 64 |
+
" noise_floor_linear = 10 ** (noise_power / 10)\n",
|
| 65 |
+
" snr_linear = signal_power / (noise_floor_linear * 1024)\n",
|
| 66 |
+
" \n",
|
| 67 |
+
" snr_db = 10 * np.log10(snr_linear)\n",
|
| 68 |
+
" \n",
|
| 69 |
+
" return round(snr_db)\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"def convert_metadata_format_real_to_simulated(signal, metadata):\n",
|
| 72 |
+
" name_mapping = {\"2FSK\": \"2-FSK\"}\n",
|
| 73 |
+
" return [\n",
|
| 74 |
+
" {\n",
|
| 75 |
+
" \"fc\": body[\"frequency_offset\"], \n",
|
| 76 |
+
" \"bw\": body[\"bandwidth\"] + 20e3,\n",
|
| 77 |
+
" \"mod\": name_mapping.get(body[\"modulation\"], body[\"modulation\"]),\n",
|
| 78 |
+
" \"snr\": calc_snr(calc_sig_power(signal, meta))\n",
|
| 79 |
+
" } for meta in metadata if (body := meta[\"body\"])\n",
|
| 80 |
+
" ]\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"def load_data(signal_id, load_metadata_only=False):\n",
|
| 83 |
+
" if not load_metadata_only:\n",
|
| 84 |
+
" signal_path = DATA_DIR / str(signal_id) / \"data.npy\"\n",
|
| 85 |
+
" if not signal_path.exists():\n",
|
| 86 |
+
" raise FileNotFoundError(f\"Signal file {signal_path} not found.\")\n",
|
| 87 |
+
" signal = np.load(signal_path)\n",
|
| 88 |
+
" else:\n",
|
| 89 |
+
" signal = None\n",
|
| 90 |
+
" with open(DATA_DIR / str(signal_id) / \"meta-data.json\") as f:\n",
|
| 91 |
+
" meta = json.load(f)\n",
|
| 92 |
+
" if isinstance(meta, dict):\n",
|
| 93 |
+
" meta = [meta]\n",
|
| 94 |
+
" return signal, convert_metadata_format_real_to_simulated(signal, meta)\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"\n",
|
| 97 |
+
" \n",
|
| 98 |
+
"def _get_all_numbered_dirs(root_dir):\n",
|
| 99 |
+
" dirs = []\n",
|
| 100 |
+
" for directory in root_dir.iterdir():\n",
|
| 101 |
+
" dirs.append(int(directory.name))\n",
|
| 102 |
+
" dirs.sort()\n",
|
| 103 |
+
" return dirs\n",
|
| 104 |
+
" \n",
|
| 105 |
+
" \n",
|
| 106 |
+
"def process_metadata(metadata):\n",
|
| 107 |
+
" scaled_metadata = [\n",
|
| 108 |
+
" {\n",
|
| 109 |
+
" \"position\": (SAMPLE_RATE/2 + i['fc'], i['bw']),\n",
|
| 110 |
+
" \"mod\": i[\"mod\"],\n",
|
| 111 |
+
" \"snr\": i[\"snr\"],\n",
|
| 112 |
+
" \"bw\": int(i['bw'])\n",
|
| 113 |
+
" }\n",
|
| 114 |
+
" for i in metadata\n",
|
| 115 |
+
" ]\n",
|
| 116 |
+
" return scaled_metadata\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"def process_signal(signal):\n",
|
| 120 |
+
" signal = signal[:STFT_LENGTH]\n",
|
| 121 |
+
"\n",
|
| 122 |
+
" signal = np.fft.fft(signal)\n",
|
| 123 |
+
" signal = np.fft.fftshift(signal)\n",
|
| 124 |
+
" signal /= np.max(np.abs(signal))\n",
|
| 125 |
+
" return signal"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "markdown",
|
| 130 |
+
"id": "440b802c",
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"source": [
|
| 133 |
+
"### Data Loading"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": null,
|
| 139 |
+
"id": "31bc3770",
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [],
|
| 142 |
+
"source": [
|
| 143 |
+
"class WidebandSignalDataset(torch.utils.data.Dataset):\n",
|
| 144 |
+
" def __init__(self, signal_ids, mask_size=MASK_SIZE, return_snrs=False):\n",
|
| 145 |
+
" self.mask_size = mask_size\n",
|
| 146 |
+
" self.signal_ids = signal_ids\n",
|
| 147 |
+
" self.return_snrs = return_snrs\n",
|
| 148 |
+
" self.snrs = []\n",
|
| 149 |
+
" loaded_data = []\n",
|
| 150 |
+
" \n",
|
| 151 |
+
" for signal_id in tqdm(self.signal_ids):\n",
|
| 152 |
+
" loaded_data.append(self.process_signal(signal_id))\n",
|
| 153 |
+
" \n",
|
| 154 |
+
" self.loaded_data = loaded_data\n",
|
| 155 |
+
"\n",
|
| 156 |
+
" def __len__(self):\n",
|
| 157 |
+
" return len(self.signal_ids)\n",
|
| 158 |
+
"\n",
|
| 159 |
+
" def __getitem__(self, index):\n",
|
| 160 |
+
" if self.return_snrs:\n",
|
| 161 |
+
" signal, masks, snr = self.loaded_data[index]\n",
|
| 162 |
+
" else:\n",
|
| 163 |
+
" signal, masks = self.loaded_data[index]\n",
|
| 164 |
+
"\n",
|
| 165 |
+
" # Ensure `signal` is complex and `masks` is real-valued\n",
|
| 166 |
+
" if not isinstance(signal, torch.Tensor):\n",
|
| 167 |
+
" signal = torch.from_numpy(signal).type(torch.complex64)\n",
|
| 168 |
+
" if not isinstance(masks, torch.Tensor):\n",
|
| 169 |
+
" masks = torch.from_numpy(masks).type(torch.FloatTensor)\n",
|
| 170 |
+
"\n",
|
| 171 |
+
" if self.return_snrs:\n",
|
| 172 |
+
" if not isinstance(snr, torch.Tensor):\n",
|
| 173 |
+
" snr = torch.tensor(snr).type(torch.FloatTensor)\n",
|
| 174 |
+
" return signal, masks, snr\n",
|
| 175 |
+
" else:\n",
|
| 176 |
+
" return signal, masks\n",
|
| 177 |
+
"\n",
|
| 178 |
+
" def process_signal(self, signal_id):\n",
|
| 179 |
+
" # Load data and metadata\n",
|
| 180 |
+
" signal, metadata = load_data(signal_id)\n",
|
| 181 |
+
" \n",
|
| 182 |
+
" # Process the metadata and create masks\n",
|
| 183 |
+
" scaled_metadata = process_metadata(metadata)\n",
|
| 184 |
+
" snrs = [meta['snr'] for meta in scaled_metadata]\n",
|
| 185 |
+
" average_snr = sum(snrs) / len(snrs) if snrs else 0\n",
|
| 186 |
+
" \n",
|
| 187 |
+
" # Convert signal to complex format and normalize it\n",
|
| 188 |
+
" signal = process_signal(signal) # `process_signal` should return np.ndarray (complex)\n",
|
| 189 |
+
" signal = torch.from_numpy(signal).type(torch.complex64) # Convert to complex tensor\n",
|
| 190 |
+
" \n",
|
| 191 |
+
" # Generate binary mask for each frequency segment\n",
|
| 192 |
+
" masks = np.zeros(self.mask_size, dtype=np.float32)\n",
|
| 193 |
+
" scale_ratio = self.mask_size / SAMPLE_RATE\n",
|
| 194 |
+
" for meta in scaled_metadata:\n",
|
| 195 |
+
" f, b = meta['position']\n",
|
| 196 |
+
" x1 = math.floor((f - b / 2) * scale_ratio)\n",
|
| 197 |
+
" x2 = math.ceil((f + b / 2) * scale_ratio)\n",
|
| 198 |
+
" masks[x1:x2] = 1\n",
|
| 199 |
+
" \n",
|
| 200 |
+
" if self.return_snrs:\n",
|
| 201 |
+
" return signal, masks, average_snr\n",
|
| 202 |
+
" else:\n",
|
| 203 |
+
" return signal, masks\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"# Train test split 80 - 10 - 10\n",
|
| 207 |
+
"train, test, validation = [], [], [] \n",
|
| 208 |
+
"total_signals = len([i for i in DATA_DIR.iterdir()])\n",
|
| 209 |
+
"for index, signal in enumerate(_get_all_numbered_dirs(DATA_DIR)):\n",
|
| 210 |
+
" if index <= 0.80 * total_signals:\n",
|
| 211 |
+
" train.append(signal)\n",
|
| 212 |
+
" elif index <= 0.9 * total_signals:\n",
|
| 213 |
+
" validation.append(signal)\n",
|
| 214 |
+
" else:\n",
|
| 215 |
+
" test.append(signal)\n",
|
| 216 |
+
" \n",
|
| 217 |
+
"print(\"Train\", len(train))\n",
|
| 218 |
+
"print(\"Validation\", len(validation))\n",
|
| 219 |
+
"print(\"Test\", len(test))\n"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "markdown",
|
| 224 |
+
"id": "3e74df1a",
|
| 225 |
+
"metadata": {},
|
| 226 |
+
"source": [
|
| 227 |
+
"### Check if complex value"
|
| 228 |
+
]
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"cell_type": "code",
|
| 232 |
+
"execution_count": null,
|
| 233 |
+
"id": "23f75344",
|
| 234 |
+
"metadata": {},
|
| 235 |
+
"outputs": [],
|
| 236 |
+
"source": [
|
| 237 |
+
"def test_single_signal_loading(signal_id):\n",
|
| 238 |
+
" # Load a single signal and process it\n",
|
| 239 |
+
" signal, metadata = load_data(signal_id)\n",
|
| 240 |
+
" \n",
|
| 241 |
+
" # Process the signal: Apply any necessary preprocessing, and convert to complex format\n",
|
| 242 |
+
" processed_signal = process_signal(signal) # This should return a complex np.ndarray\n",
|
| 243 |
+
" complex_signal = torch.from_numpy(processed_signal).type(torch.complex64)\n",
|
| 244 |
+
" \n",
|
| 245 |
+
" # Check if the signal is complex\n",
|
| 246 |
+
" print(\"Loaded Signal ID:\", signal_id)\n",
|
| 247 |
+
" print(\"Signal Type:\", complex_signal.dtype)\n",
|
| 248 |
+
" print(\"Signal Shape:\", complex_signal.shape)\n",
|
| 249 |
+
" \n",
|
| 250 |
+
" # Generate the mask as you would in WidebandSignalDataset\n",
|
| 251 |
+
" scaled_metadata = process_metadata(metadata)\n",
|
| 252 |
+
" masks = np.zeros(MASK_SIZE, dtype=np.float32)\n",
|
| 253 |
+
" scale_ratio = MASK_SIZE / SAMPLE_RATE\n",
|
| 254 |
+
" for meta in scaled_metadata:\n",
|
| 255 |
+
" f, b = meta['position']\n",
|
| 256 |
+
" x1 = math.floor((f - b / 2) * scale_ratio)\n",
|
| 257 |
+
" x2 = math.ceil((f + b / 2) * scale_ratio)\n",
|
| 258 |
+
" masks[x1:x2] = 1\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" # Convert mask to tensor\n",
|
| 261 |
+
" mask_tensor = torch.from_numpy(masks).type(torch.FloatTensor)\n",
|
| 262 |
+
"\n",
|
| 263 |
+
" # Output information about the mask\n",
|
| 264 |
+
" print(\"Mask Shape:\", mask_tensor.shape)\n",
|
| 265 |
+
" print(\"Mask Type:\", mask_tensor.dtype)\n",
|
| 266 |
+
" \n",
|
| 267 |
+
" return complex_signal, mask_tensor\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"# Test with a specific signal_id (replace with an actual ID from your data)\n",
|
| 270 |
+
"test_signal_id = train[0] # Assuming `train` list contains valid signal IDs\n",
|
| 271 |
+
"complex_signal, mask_tensor = test_single_signal_loading(test_signal_id)\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"# Optional: Check a sample value to confirm it's complex\n",
|
| 274 |
+
"print(\"Sample value from signal tensor:\", complex_signal[0])"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"cell_type": "code",
|
| 279 |
+
"execution_count": null,
|
| 280 |
+
"id": "1cec9c6e",
|
| 281 |
+
"metadata": {},
|
| 282 |
+
"outputs": [],
|
| 283 |
+
"source": [
|
| 284 |
+
"train_dataset = WidebandSignalDataset(signal_ids=train)\n",
|
| 285 |
+
"validation_dataset = WidebandSignalDataset(signal_ids=validation)\n",
|
| 286 |
+
"test_dataset = WidebandSignalDataset(signal_ids=test)"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"cell_type": "markdown",
|
| 291 |
+
"id": "e0900d4e",
|
| 292 |
+
"metadata": {},
|
| 293 |
+
"source": [
|
| 294 |
+
"### Check SNR"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "code",
|
| 299 |
+
"execution_count": null,
|
| 300 |
+
"id": "2fbee106",
|
| 301 |
+
"metadata": {},
|
| 302 |
+
"outputs": [],
|
| 303 |
+
"source": [
|
| 304 |
+
"import matplotlib.pyplot as plt\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"# For Train Dataset\n",
|
| 307 |
+
"train_snrs = train_dataset.snrs\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"# Plot Histogram of SNRs in Train Dataset\n",
|
| 310 |
+
"plt.figure(figsize=(10, 6))\n",
|
| 311 |
+
"plt.hist(train_snrs, bins=range(int(min(train_snrs)), int(max(train_snrs)) + 1), edgecolor='black')\n",
|
| 312 |
+
"plt.title('Histogram of SNRs in Train Dataset')\n",
|
| 313 |
+
"plt.xlabel('SNR (dB)')\n",
|
| 314 |
+
"plt.ylabel('Number of Samples')\n",
|
| 315 |
+
"plt.grid(True)\n",
|
| 316 |
+
"plt.show()\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"# Print SNR Range\n",
|
| 319 |
+
"print('Train Dataset SNR range: {} dB to {} dB'.format(min(train_snrs), max(train_snrs)))\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"# For Validation Dataset\n",
|
| 322 |
+
"validation_snrs = validation_dataset.snrs\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"# Plot Histogram of SNRs in Validation Dataset\n",
|
| 325 |
+
"plt.figure(figsize=(10, 6))\n",
|
| 326 |
+
"plt.hist(validation_snrs, bins=range(int(min(validation_snrs)), int(max(validation_snrs)) + 1), edgecolor='black')\n",
|
| 327 |
+
"plt.title('Histogram of SNRs in Validation Dataset')\n",
|
| 328 |
+
"plt.xlabel('SNR (dB)')\n",
|
| 329 |
+
"plt.ylabel('Number of Samples')\n",
|
| 330 |
+
"plt.grid(True)\n",
|
| 331 |
+
"plt.show()\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"# Print SNR Range\n",
|
| 334 |
+
"print('Validation Dataset SNR range: {} dB to {} dB'.format(min(validation_snrs), max(validation_snrs)))\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"# For Test Dataset\n",
|
| 337 |
+
"test_snrs = test_dataset.snrs\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"# Plot Histogram of SNRs in Validation Dataset\n",
|
| 340 |
+
"plt.figure(figsize=(10, 6))\n",
|
| 341 |
+
"plt.hist(test_snrs, bins=range(int(min(test_snrs)), int(max(test_snrs)) + 1), edgecolor='black')\n",
|
| 342 |
+
"plt.title('Histogram of SNRs in Test Dataset')\n",
|
| 343 |
+
"plt.xlabel('SNR (dB)')\n",
|
| 344 |
+
"plt.ylabel('Number of Samples')\n",
|
| 345 |
+
"plt.grid(True)\n",
|
| 346 |
+
"plt.show()\n",
|
| 347 |
+
"\n",
|
| 348 |
+
"# Print SNR Range\n",
|
| 349 |
+
"print('Validation Dataset SNR range: {} dB to {} dB'.format(min(test_snrs), max(test_snrs)))\n"
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"cell_type": "markdown",
|
| 354 |
+
"id": "637ae774",
|
| 355 |
+
"metadata": {},
|
| 356 |
+
"source": [
|
| 357 |
+
"### Batch Loading"
|
| 358 |
+
]
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"cell_type": "code",
|
| 362 |
+
"execution_count": null,
|
| 363 |
+
"id": "a9af2450",
|
| 364 |
+
"metadata": {},
|
| 365 |
+
"outputs": [],
|
| 366 |
+
"source": [
|
| 367 |
+
"batch_size = 64 # Updated batch size\n",
|
| 368 |
+
"\n",
|
| 369 |
+
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 370 |
+
"valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"print(\"Train labels shape:\", len(train_dataset))\n",
|
| 373 |
+
"print(\"Validation labels shape:\", len(validation_dataset))"
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"cell_type": "markdown",
|
| 378 |
+
"id": "9a8e09e4",
|
| 379 |
+
"metadata": {},
|
| 380 |
+
"source": [
|
| 381 |
+
"### Early Stop"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"cell_type": "code",
|
| 386 |
+
"execution_count": null,
|
| 387 |
+
"id": "24f79a24",
|
| 388 |
+
"metadata": {},
|
| 389 |
+
"outputs": [],
|
| 390 |
+
"source": [
|
| 391 |
+
"import os\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"class EarlyStopping:\n",
|
| 394 |
+
" def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./path/to/model/save'):\n",
|
| 395 |
+
" self.patience = patience\n",
|
| 396 |
+
" self.verbose = verbose\n",
|
| 397 |
+
" self.delta = delta\n",
|
| 398 |
+
" self.counter = 0\n",
|
| 399 |
+
" self.best_score = None\n",
|
| 400 |
+
" self.early_stop = False\n",
|
| 401 |
+
" self.val_loss_min = float('inf')\n",
|
| 402 |
+
" self.best_model = None\n",
|
| 403 |
+
" self.save_path = save_path\n",
|
| 404 |
+
" os.makedirs(save_path, exist_ok=True)\n",
|
| 405 |
+
" \n",
|
| 406 |
+
" def __call__(self, val_loss, model):\n",
|
| 407 |
+
" score = -val_loss\n",
|
| 408 |
+
"\n",
|
| 409 |
+
" if self.best_score is None:\n",
|
| 410 |
+
" self.best_score = score\n",
|
| 411 |
+
" self.save_checkpoint(val_loss, model)\n",
|
| 412 |
+
" elif score < self.best_score + self.delta:\n",
|
| 413 |
+
" self.counter += 1\n",
|
| 414 |
+
" if self.verbose:\n",
|
| 415 |
+
" print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
|
| 416 |
+
" if self.counter >= self.patience:\n",
|
| 417 |
+
" self.early_stop = True\n",
|
| 418 |
+
" else:\n",
|
| 419 |
+
" self.best_score = score\n",
|
| 420 |
+
" self.save_checkpoint(val_loss, model)\n",
|
| 421 |
+
" self.counter = 0\n",
|
| 422 |
+
"\n",
|
| 423 |
+
" def save_checkpoint(self, val_loss, model):\n",
|
| 424 |
+
" if self.verbose:\n",
|
| 425 |
+
" print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n",
|
| 426 |
+
" self.val_loss_min = val_loss\n",
|
| 427 |
+
" self.best_model = model.state_dict()\n",
|
| 428 |
+
" save_path = os.path.join(self.save_path, 'best_model.pth')\n",
|
| 429 |
+
" torch.save(self.best_model, save_path)"
|
| 430 |
+
]
|
| 431 |
+
},
|
| 432 |
+
{
|
| 433 |
+
"cell_type": "markdown",
|
| 434 |
+
"id": "6c3fda74",
|
| 435 |
+
"metadata": {},
|
| 436 |
+
"source": [
|
| 437 |
+
"### Reshape"
|
| 438 |
+
]
|
| 439 |
+
},
|
| 440 |
+
{
|
| 441 |
+
"cell_type": "code",
|
| 442 |
+
"execution_count": null,
|
| 443 |
+
"id": "5fcf91db",
|
| 444 |
+
"metadata": {},
|
| 445 |
+
"outputs": [],
|
| 446 |
+
"source": [
|
| 447 |
+
"import torch.nn as nn\n",
|
| 448 |
+
"import complexPyTorch.complexLayers as cplx\n",
|
| 449 |
+
"import torch.nn.functional as F\n",
|
| 450 |
+
"import torch\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"def reshape_to_2d(data):\n",
|
| 453 |
+
" return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]"
|
| 454 |
+
]
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"cell_type": "markdown",
|
| 458 |
+
"id": "b7d7562c",
|
| 459 |
+
"metadata": {},
|
| 460 |
+
"source": [
|
| 461 |
+
"### Complex IoU"
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"cell_type": "code",
|
| 466 |
+
"execution_count": null,
|
| 467 |
+
"id": "76b9d084",
|
| 468 |
+
"metadata": {},
|
| 469 |
+
"outputs": [],
|
| 470 |
+
"source": [
|
| 471 |
+
"def calculate_iou(pred, target, threshold=0.5):\n",
|
| 472 |
+
" real_pred = (pred.real > threshold).float()\n",
|
| 473 |
+
" imag_pred = (pred.imag > threshold).float()\n",
|
| 474 |
+
" \n",
|
| 475 |
+
" combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
|
| 476 |
+
" \n",
|
| 477 |
+
" intersection = (combined_pred * target).sum(dim=1)\n",
|
| 478 |
+
" union = (combined_pred + target).sum(dim=1) - intersection\n",
|
| 479 |
+
" iou = (intersection / union).mean().item()\n",
|
| 480 |
+
" return iou"
|
| 481 |
+
]
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"cell_type": "markdown",
|
| 485 |
+
"id": "64f4063c",
|
| 486 |
+
"metadata": {},
|
| 487 |
+
"source": [
|
| 488 |
+
"### Training"
|
| 489 |
+
]
|
| 490 |
+
},
|
| 491 |
+
{
|
| 492 |
+
"cell_type": "code",
|
| 493 |
+
"execution_count": null,
|
| 494 |
+
"id": "66825110",
|
| 495 |
+
"metadata": {},
|
| 496 |
+
"outputs": [],
|
| 497 |
+
"source": [
|
| 498 |
+
"import time\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"def validate_model(model, valid_loader, criterion):\n",
|
| 501 |
+
" model.eval()\n",
|
| 502 |
+
" running_loss = 0.0\n",
|
| 503 |
+
" iou_scores = []\n",
|
| 504 |
+
" total_correct = 0\n",
|
| 505 |
+
" total_samples = 0\n",
|
| 506 |
+
"\n",
|
| 507 |
+
" with torch.no_grad():\n",
|
| 508 |
+
" for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n",
|
| 509 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 510 |
+
" masks = masks.to(device)\n",
|
| 511 |
+
" outputs = model(inputs)\n",
|
| 512 |
+
" loss = criterion(outputs, masks)\n",
|
| 513 |
+
" running_loss += loss.item()\n",
|
| 514 |
+
"\n",
|
| 515 |
+
" # Calculate IoU\n",
|
| 516 |
+
" iou = calculate_iou(outputs, masks, threshold=0.5)\n",
|
| 517 |
+
" iou_scores.append(iou)\n",
|
| 518 |
+
" \n",
|
| 519 |
+
" # Calculate accuracy\n",
|
| 520 |
+
" preds = (outputs.real > 0.5).float()\n",
|
| 521 |
+
" correct = (preds == masks).float().sum()\n",
|
| 522 |
+
" total_correct += correct.item()\n",
|
| 523 |
+
" total_samples += masks.numel()\n",
|
| 524 |
+
"\n",
|
| 525 |
+
" val_loss = running_loss / len(valid_loader)\n",
|
| 526 |
+
" mean_iou = sum(iou_scores) / len(iou_scores)\n",
|
| 527 |
+
" accuracy = total_correct / total_samples * 100\n",
|
| 528 |
+
"\n",
|
| 529 |
+
" print(f'Validation Loss: {val_loss:.6f}')\n",
|
| 530 |
+
" print(f'Validation Accuracy: {accuracy:.2f}%')\n",
|
| 531 |
+
"\n",
|
| 532 |
+
" return val_loss, accuracy\n",
|
| 533 |
+
"\n",
|
| 534 |
+
"def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.000001], num_epochs=50, patience=5):\n",
|
| 535 |
+
" train_losses = []\n",
|
| 536 |
+
" val_losses = []\n",
|
| 537 |
+
" val_accuracies = []\n",
|
| 538 |
+
" epoch_durations = []\n",
|
| 539 |
+
" \n",
|
| 540 |
+
" current_lr = initial_lr\n",
|
| 541 |
+
" for lr in lr_steps:\n",
|
| 542 |
+
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
|
| 543 |
+
" early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n",
|
| 544 |
+
" print(\"Current learning rate: \", lr)\n",
|
| 545 |
+
" for epoch in range(num_epochs):\n",
|
| 546 |
+
" epoch_start_time = time.time()\n",
|
| 547 |
+
" \n",
|
| 548 |
+
" model.train()\n",
|
| 549 |
+
" running_loss = 0.0\n",
|
| 550 |
+
" for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n",
|
| 551 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 552 |
+
" masks = masks.to(device)\n",
|
| 553 |
+
" outputs = model(inputs)\n",
|
| 554 |
+
" loss = criterion(outputs, masks)\n",
|
| 555 |
+
"\n",
|
| 556 |
+
" optimizer.zero_grad()\n",
|
| 557 |
+
" loss.backward()\n",
|
| 558 |
+
" optimizer.step()\n",
|
| 559 |
+
"\n",
|
| 560 |
+
" running_loss += loss.item()\n",
|
| 561 |
+
"\n",
|
| 562 |
+
" epoch_loss = running_loss / len(train_loader)\n",
|
| 563 |
+
" train_losses.append(epoch_loss)\n",
|
| 564 |
+
" print(f\"Training Loss: {epoch_loss:.6f}\")\n",
|
| 565 |
+
" \n",
|
| 566 |
+
" val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n",
|
| 567 |
+
" val_losses.append(val_loss)\n",
|
| 568 |
+
" val_accuracies.append(val_accuracy)\n",
|
| 569 |
+
" early_stopping(val_loss, model)\n",
|
| 570 |
+
"\n",
|
| 571 |
+
" if early_stopping.early_stop:\n",
|
| 572 |
+
" print(\"Early stopping triggered\")\n",
|
| 573 |
+
" break\n",
|
| 574 |
+
"\n",
|
| 575 |
+
" epoch_duration = time.time() - epoch_start_time\n",
|
| 576 |
+
" epoch_durations.append(epoch_duration)\n",
|
| 577 |
+
" if early_stopping.best_model is not None:\n",
|
| 578 |
+
" print(f\"Loading best model from lr {lr}\")\n",
|
| 579 |
+
" model.load_state_dict(early_stopping.best_model)\n",
|
| 580 |
+
" \n",
|
| 581 |
+
" print(\"Training completed.\")\n",
|
| 582 |
+
" print(\"Epoch durations:\", epoch_durations)\n",
|
| 583 |
+
" return model, train_losses, val_losses, val_accuracies, epoch_durations"
|
| 584 |
+
]
|
| 585 |
+
},
|
| 586 |
+
{
|
| 587 |
+
"cell_type": "markdown",
|
| 588 |
+
"id": "0b80cb51",
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"source": [
|
| 591 |
+
"### ResNet-18"
|
| 592 |
+
]
|
| 593 |
+
},
|
| 594 |
+
{
|
| 595 |
+
"cell_type": "code",
|
| 596 |
+
"execution_count": null,
|
| 597 |
+
"id": "2d208cb9",
|
| 598 |
+
"metadata": {},
|
| 599 |
+
"outputs": [],
|
| 600 |
+
"source": [
|
| 601 |
+
"import torch\n",
|
| 602 |
+
"import torch.nn as nn\n",
|
| 603 |
+
"import complexPyTorch.complexLayers as cplx\n",
|
| 604 |
+
"from typing import Optional, Callable, Type, Union, List\n",
|
| 605 |
+
"import torch.nn.functional as F\n",
|
| 606 |
+
"from torch import Tensor\n",
|
| 607 |
+
"\n",
|
| 608 |
+
"def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
|
| 609 |
+
" \"\"\"3x3 convolution with padding\"\"\"\n",
|
| 610 |
+
" return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
|
| 611 |
+
"\n",
|
| 612 |
+
"def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
|
| 613 |
+
" \"\"\"1x1 convolution\"\"\"\n",
|
| 614 |
+
" return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
|
| 615 |
+
"\n",
|
| 616 |
+
"class BasicBlock(nn.Module):\n",
|
| 617 |
+
" expansion = 1\n",
|
| 618 |
+
"\n",
|
| 619 |
+
" def __init__(\n",
|
| 620 |
+
" self,\n",
|
| 621 |
+
" inplanes: int,\n",
|
| 622 |
+
" planes: int,\n",
|
| 623 |
+
" stride: int = 1,\n",
|
| 624 |
+
" downsample: Optional[nn.Module] = None,\n",
|
| 625 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 626 |
+
" ) -> None:\n",
|
| 627 |
+
" super(BasicBlock, self).__init__()\n",
|
| 628 |
+
" self.conv1 = conv3x3(inplanes, planes, stride)\n",
|
| 629 |
+
" self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 630 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 631 |
+
" self.conv2 = conv3x3(planes, planes)\n",
|
| 632 |
+
" self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 633 |
+
" self.downsample = downsample\n",
|
| 634 |
+
" self.stride = stride\n",
|
| 635 |
+
"\n",
|
| 636 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 637 |
+
" identity = x\n",
|
| 638 |
+
"\n",
|
| 639 |
+
" out = self.conv1(x)\n",
|
| 640 |
+
" out = self.bn1(out)\n",
|
| 641 |
+
" out = self.relu(out)\n",
|
| 642 |
+
"\n",
|
| 643 |
+
" out = self.conv2(out)\n",
|
| 644 |
+
" out = self.bn2(out)\n",
|
| 645 |
+
"\n",
|
| 646 |
+
" if self.downsample is not None:\n",
|
| 647 |
+
" identity = self.downsample(x)\n",
|
| 648 |
+
"\n",
|
| 649 |
+
" out += identity\n",
|
| 650 |
+
" out = self.relu(out)\n",
|
| 651 |
+
"\n",
|
| 652 |
+
" return out\n",
|
| 653 |
+
"\n",
|
| 654 |
+
"class Bottleneck(nn.Module):\n",
|
| 655 |
+
" expansion = 4\n",
|
| 656 |
+
"\n",
|
| 657 |
+
" def __init__(\n",
|
| 658 |
+
" self,\n",
|
| 659 |
+
" inplanes: int,\n",
|
| 660 |
+
" planes: int,\n",
|
| 661 |
+
" stride: int = 1,\n",
|
| 662 |
+
" downsample: Optional[nn.Module] = None,\n",
|
| 663 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 664 |
+
" ) -> None:\n",
|
| 665 |
+
" super(Bottleneck, self).__init__()\n",
|
| 666 |
+
" self.conv1 = conv1x1(inplanes, planes)\n",
|
| 667 |
+
" self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 668 |
+
" self.conv2 = conv3x3(planes, planes, stride)\n",
|
| 669 |
+
" self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 670 |
+
" self.conv3 = conv1x1(planes, planes * self.expansion)\n",
|
| 671 |
+
" self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n",
|
| 672 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 673 |
+
" self.downsample = downsample\n",
|
| 674 |
+
" self.stride = stride\n",
|
| 675 |
+
"\n",
|
| 676 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 677 |
+
" identity = x\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" out = self.conv1(x)\n",
|
| 680 |
+
" out = self.bn1(out)\n",
|
| 681 |
+
" out = self.relu(out)\n",
|
| 682 |
+
"\n",
|
| 683 |
+
" out = self.conv2(out)\n",
|
| 684 |
+
" out = self.bn2(out)\n",
|
| 685 |
+
" out = self.relu(out)\n",
|
| 686 |
+
"\n",
|
| 687 |
+
" out = self.conv3(out)\n",
|
| 688 |
+
" out = self.bn3(out)\n",
|
| 689 |
+
"\n",
|
| 690 |
+
" if self.downsample is not None:\n",
|
| 691 |
+
" identity = self.downsample(x)\n",
|
| 692 |
+
"\n",
|
| 693 |
+
" out += identity\n",
|
| 694 |
+
" out = self.relu(out)\n",
|
| 695 |
+
"\n",
|
| 696 |
+
" return out\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"class ComplexResNet(nn.Module):\n",
|
| 699 |
+
" def __init__(\n",
|
| 700 |
+
" self,\n",
|
| 701 |
+
" block: Type[Union[BasicBlock, Bottleneck]],\n",
|
| 702 |
+
" layers: List[int],\n",
|
| 703 |
+
" num_classes: int = STFT_LENGTH,\n",
|
| 704 |
+
" zero_init_residual: bool = False,\n",
|
| 705 |
+
" groups: int = 1,\n",
|
| 706 |
+
" width_per_group: int = 64,\n",
|
| 707 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 708 |
+
" ) -> None:\n",
|
| 709 |
+
" super(ComplexResNet, self).__init__()\n",
|
| 710 |
+
" if norm_layer is None:\n",
|
| 711 |
+
" norm_layer = cplx.ComplexBatchNorm2d\n",
|
| 712 |
+
" self._norm_layer = norm_layer\n",
|
| 713 |
+
"\n",
|
| 714 |
+
" self.inplanes = 64\n",
|
| 715 |
+
" self.dilation = 1\n",
|
| 716 |
+
"\n",
|
| 717 |
+
" self.groups = groups\n",
|
| 718 |
+
" self.base_width = width_per_group\n",
|
| 719 |
+
" self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n",
|
| 720 |
+
" self.bn1 = norm_layer(self.inplanes)\n",
|
| 721 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 722 |
+
" self.layer1 = self._make_layer(block, 64, layers[0])\n",
|
| 723 |
+
" self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
|
| 724 |
+
" self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
|
| 725 |
+
" self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
|
| 726 |
+
" self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
|
| 727 |
+
" self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n",
|
| 728 |
+
" self.sigmoid = cplx.ComplexSigmoid()\n",
|
| 729 |
+
"\n",
|
| 730 |
+
" def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n",
|
| 731 |
+
" norm_layer = self._norm_layer\n",
|
| 732 |
+
" downsample = None\n",
|
| 733 |
+
" if stride != 1 or self.inplanes != planes * block.expansion:\n",
|
| 734 |
+
" downsample = nn.Sequential(\n",
|
| 735 |
+
" conv1x1(self.inplanes, planes * block.expansion, stride),\n",
|
| 736 |
+
" norm_layer(planes * block.expansion),\n",
|
| 737 |
+
" )\n",
|
| 738 |
+
"\n",
|
| 739 |
+
" layers = []\n",
|
| 740 |
+
" layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n",
|
| 741 |
+
" self.inplanes = planes * block.expansion\n",
|
| 742 |
+
" for _ in range(1, blocks):\n",
|
| 743 |
+
" layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n",
|
| 744 |
+
"\n",
|
| 745 |
+
" return nn.Sequential(*layers)\n",
|
| 746 |
+
"\n",
|
| 747 |
+
" def _forward_impl(self, x: Tensor) -> Tensor:\n",
|
| 748 |
+
" x = self.conv1(x)\n",
|
| 749 |
+
" x = self.bn1(x)\n",
|
| 750 |
+
" x = self.relu(x)\n",
|
| 751 |
+
"\n",
|
| 752 |
+
" x = self.layer1(x)\n",
|
| 753 |
+
" x = self.layer2(x)\n",
|
| 754 |
+
" x = self.layer3(x)\n",
|
| 755 |
+
" x = self.layer4(x)\n",
|
| 756 |
+
"\n",
|
| 757 |
+
" x = self.avgpool(x)\n",
|
| 758 |
+
" x = torch.flatten(x, 1)\n",
|
| 759 |
+
" x = self.fc(x)\n",
|
| 760 |
+
" x = self.sigmoid(x)\n",
|
| 761 |
+
" return x\n",
|
| 762 |
+
"\n",
|
| 763 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 764 |
+
" return self._forward_impl(x)\n",
|
| 765 |
+
"\n",
|
| 766 |
+
"def ComplexResNet18():\n",
|
| 767 |
+
" return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n",
|
| 768 |
+
"\n",
|
| 769 |
+
"# Create the model instance\n",
|
| 770 |
+
"model = ComplexResNet18()\n",
|
| 771 |
+
"print(model)\n"
|
| 772 |
+
]
|
| 773 |
+
},
|
| 774 |
+
{
|
| 775 |
+
"cell_type": "markdown",
|
| 776 |
+
"id": "e4bc1b5d",
|
| 777 |
+
"metadata": {},
|
| 778 |
+
"source": [
|
| 779 |
+
"### Complex focal Loss"
|
| 780 |
+
]
|
| 781 |
+
},
|
| 782 |
+
{
|
| 783 |
+
"cell_type": "code",
|
| 784 |
+
"execution_count": null,
|
| 785 |
+
"id": "61c29429",
|
| 786 |
+
"metadata": {},
|
| 787 |
+
"outputs": [],
|
| 788 |
+
"source": [
|
| 789 |
+
"class ComplexFocalLoss(nn.Module):\n",
|
| 790 |
+
" def __init__(self, alpha=1, gamma=2, reduction='mean'):\n",
|
| 791 |
+
" super(ComplexFocalLoss, self).__init__()\n",
|
| 792 |
+
" self.alpha = alpha\n",
|
| 793 |
+
" self.gamma = gamma\n",
|
| 794 |
+
" self.reduction = reduction\n",
|
| 795 |
+
"\n",
|
| 796 |
+
" def forward(self, inputs, targets):\n",
|
| 797 |
+
" real_inputs = inputs.real\n",
|
| 798 |
+
" imag_inputs = inputs.imag\n",
|
| 799 |
+
" \n",
|
| 800 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
|
| 801 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
|
| 802 |
+
" \n",
|
| 803 |
+
" real_pt = torch.exp(-real_BCE_loss)\n",
|
| 804 |
+
" imag_pt = torch.exp(-imag_BCE_loss)\n",
|
| 805 |
+
" \n",
|
| 806 |
+
" real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
|
| 807 |
+
" imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
|
| 808 |
+
"\n",
|
| 809 |
+
" if self.reduction == 'mean':\n",
|
| 810 |
+
" return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
|
| 811 |
+
" elif self.reduction == 'sum':\n",
|
| 812 |
+
" return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
|
| 813 |
+
" else:\n",
|
| 814 |
+
" return real_F_loss + imag_F_loss"
|
| 815 |
+
]
|
| 816 |
+
},
|
| 817 |
+
{
|
| 818 |
+
"cell_type": "markdown",
|
| 819 |
+
"id": "abb35ba2",
|
| 820 |
+
"metadata": {},
|
| 821 |
+
"source": [
|
| 822 |
+
"### Training with complex focal loss"
|
| 823 |
+
]
|
| 824 |
+
},
|
| 825 |
+
{
|
| 826 |
+
"cell_type": "code",
|
| 827 |
+
"execution_count": null,
|
| 828 |
+
"id": "86d7526b",
|
| 829 |
+
"metadata": {
|
| 830 |
+
"scrolled": false
|
| 831 |
+
},
|
| 832 |
+
"outputs": [],
|
| 833 |
+
"source": [
|
| 834 |
+
"# Initialize and train the ResNet-18 model\n",
|
| 835 |
+
"model = ComplexResNet18().to(device)\n",
|
| 836 |
+
"criterion = ComplexFocalLoss()\n",
|
| 837 |
+
"\n",
|
| 838 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n",
|
| 839 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 840 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 841 |
+
]
|
| 842 |
+
},
|
| 843 |
+
{
|
| 844 |
+
"cell_type": "markdown",
|
| 845 |
+
"id": "fd0c9d58",
|
| 846 |
+
"metadata": {},
|
| 847 |
+
"source": [
|
| 848 |
+
"### CVNN RV-BCE and CV-BCE Loss function implementation"
|
| 849 |
+
]
|
| 850 |
+
},
|
| 851 |
+
{
|
| 852 |
+
"cell_type": "code",
|
| 853 |
+
"execution_count": null,
|
| 854 |
+
"id": "99c736b8",
|
| 855 |
+
"metadata": {},
|
| 856 |
+
"outputs": [],
|
| 857 |
+
"source": [
|
| 858 |
+
"# CV BCE Loss Function Definition\n",
|
| 859 |
+
"class ComplexValuedBCELoss(nn.Module):\n",
|
| 860 |
+
" def __init__(self, reduction='mean'):\n",
|
| 861 |
+
" super(ComplexValuedBCELoss, self).__init__()\n",
|
| 862 |
+
" self.reduction = reduction\n",
|
| 863 |
+
"\n",
|
| 864 |
+
" def forward(self, inputs, targets):\n",
|
| 865 |
+
" real_inputs = inputs.real\n",
|
| 866 |
+
" imag_inputs = inputs.imag\n",
|
| 867 |
+
"\n",
|
| 868 |
+
" # Calculate binary cross-entropy for both real and imaginary parts\n",
|
| 869 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
|
| 870 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n",
|
| 871 |
+
" \n",
|
| 872 |
+
" # Combine the losses (you can adjust the weighting if necessary)\n",
|
| 873 |
+
" combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n",
|
| 874 |
+
" return combined_BCE_loss"
|
| 875 |
+
]
|
| 876 |
+
},
|
| 877 |
+
{
|
| 878 |
+
"cell_type": "markdown",
|
| 879 |
+
"id": "93d19ea7",
|
| 880 |
+
"metadata": {},
|
| 881 |
+
"source": [
|
| 882 |
+
"### CV-BCE Training"
|
| 883 |
+
]
|
| 884 |
+
},
|
| 885 |
+
{
|
| 886 |
+
"cell_type": "code",
|
| 887 |
+
"execution_count": null,
|
| 888 |
+
"id": "2c56d5b4",
|
| 889 |
+
"metadata": {
|
| 890 |
+
"scrolled": false
|
| 891 |
+
},
|
| 892 |
+
"outputs": [],
|
| 893 |
+
"source": [
|
| 894 |
+
"# Set the criterion for CV BCE\n",
|
| 895 |
+
"criterion = ComplexValuedBCELoss()\n",
|
| 896 |
+
"\n",
|
| 897 |
+
"# Train the ResNet-18 model with CV BCE\n",
|
| 898 |
+
"device = torch.device('cuda')\n",
|
| 899 |
+
"model = ComplexResNet18().to(device)\n",
|
| 900 |
+
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
|
| 901 |
+
"\n",
|
| 902 |
+
"# Start training with the previously defined train_model function\n",
|
| 903 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n",
|
| 904 |
+
" model, train_loader, valid_loader, criterion, \n",
|
| 905 |
+
" initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
|
| 906 |
+
")\n",
|
| 907 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 908 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 909 |
+
]
|
| 910 |
+
},
|
| 911 |
+
{
|
| 912 |
+
"cell_type": "markdown",
|
| 913 |
+
"id": "7ccd50ff",
|
| 914 |
+
"metadata": {},
|
| 915 |
+
"source": [
|
| 916 |
+
"### Save and Plot"
|
| 917 |
+
]
|
| 918 |
+
},
|
| 919 |
+
{
|
| 920 |
+
"cell_type": "code",
|
| 921 |
+
"execution_count": null,
|
| 922 |
+
"id": "eb41b92f",
|
| 923 |
+
"metadata": {},
|
| 924 |
+
"outputs": [],
|
| 925 |
+
"source": [
|
| 926 |
+
"import os\n",
|
| 927 |
+
"import json\n",
|
| 928 |
+
"import matplotlib.pyplot as plt\n",
|
| 929 |
+
"\n",
|
| 930 |
+
"# Define save directory\n",
|
| 931 |
+
"save_dir = 'CMuSeNet_results/segmentation_OTA'\n",
|
| 932 |
+
"\n",
|
| 933 |
+
"# Create the directory if it doesn't exist\n",
|
| 934 |
+
"os.makedirs(save_dir, exist_ok=True)\n",
|
| 935 |
+
"\n",
|
| 936 |
+
"# Plot training loss\n",
|
| 937 |
+
"plt.figure()\n",
|
| 938 |
+
"plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', color='blue')\n",
|
| 939 |
+
"plt.title('Training Loss')\n",
|
| 940 |
+
"plt.xlabel('Epoch')\n",
|
| 941 |
+
"plt.ylabel('Loss')\n",
|
| 942 |
+
"plt.legend()\n",
|
| 943 |
+
"\n",
|
| 944 |
+
"# Save the training loss figure as PNG and SVG\n",
|
| 945 |
+
"plt.savefig(os.path.join(save_dir, 'training_loss.png'))\n",
|
| 946 |
+
"plt.savefig(os.path.join(save_dir, 'training_loss.svg'))\n",
|
| 947 |
+
"\n",
|
| 948 |
+
"# Show the training loss plot\n",
|
| 949 |
+
"plt.show()\n",
|
| 950 |
+
"\n",
|
| 951 |
+
"# Plot validation accuracy\n",
|
| 952 |
+
"plt.figure()\n",
|
| 953 |
+
"plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', color='green')\n",
|
| 954 |
+
"plt.title('Validation Accuracy')\n",
|
| 955 |
+
"plt.xlabel('Epoch')\n",
|
| 956 |
+
"plt.ylabel('Accuracy')\n",
|
| 957 |
+
"plt.legend()\n",
|
| 958 |
+
"\n",
|
| 959 |
+
"# Save the validation accuracy figure as PNG and SVG\n",
|
| 960 |
+
"plt.savefig(os.path.join(save_dir, 'validation_accuracy.png'))\n",
|
| 961 |
+
"plt.savefig(os.path.join(save_dir, 'validation_accuracy.svg'))\n",
|
| 962 |
+
"\n",
|
| 963 |
+
"# Show the validation accuracy plot\n",
|
| 964 |
+
"plt.show()\n",
|
| 965 |
+
"\n",
|
| 966 |
+
"# Save the actual data to a JSON file\n",
|
| 967 |
+
"results = {\n",
|
| 968 |
+
" \"train_losses\": train_losses,\n",
|
| 969 |
+
" \"val_accuracies\": val_accuracies,\n",
|
| 970 |
+
" \"epoch_durations\": epoch_durations,\n",
|
| 971 |
+
" \"combined_epoch_time\": combined_epoch_time\n",
|
| 972 |
+
"}\n",
|
| 973 |
+
"\n",
|
| 974 |
+
"# Save JSON file\n",
|
| 975 |
+
"with open(os.path.join(save_dir, 'training_validation_results.json'), 'w') as f:\n",
|
| 976 |
+
" json.dump(results, f)"
|
| 977 |
+
]
|
| 978 |
+
},
|
| 979 |
+
{
|
| 980 |
+
"cell_type": "markdown",
|
| 981 |
+
"id": "3a757949",
|
| 982 |
+
"metadata": {},
|
| 983 |
+
"source": [
|
| 984 |
+
"### Transfer Learning from Synthetic model"
|
| 985 |
+
]
|
| 986 |
+
},
|
| 987 |
+
{
|
| 988 |
+
"cell_type": "markdown",
|
| 989 |
+
"id": "ee265d28",
|
| 990 |
+
"metadata": {},
|
| 991 |
+
"source": [
|
| 992 |
+
"### Load pre-trained model"
|
| 993 |
+
]
|
| 994 |
+
},
|
| 995 |
+
{
|
| 996 |
+
"cell_type": "code",
|
| 997 |
+
"execution_count": null,
|
| 998 |
+
"id": "0dec6746",
|
| 999 |
+
"metadata": {},
|
| 1000 |
+
"outputs": [],
|
| 1001 |
+
"source": [
|
| 1002 |
+
"# Block to load pre-trained model and prepare for transfer learning\n",
|
| 1003 |
+
"device = torch.device(\"cuda\")\n",
|
| 1004 |
+
"\n",
|
| 1005 |
+
"# Load the pre-trained model\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
"model_path = \"path/to/model/save.pth\"\n",
|
| 1008 |
+
"model = ComplexResNet18().to(device)\n",
|
| 1009 |
+
"model.load_state_dict(torch.load(model_path, map_location=device))\n",
|
| 1010 |
+
"\n",
|
| 1011 |
+
"# Freeze all layers except the final layer\n",
|
| 1012 |
+
"for param in model.parameters():\n",
|
| 1013 |
+
" param.requires_grad = False\n",
|
| 1014 |
+
"\n",
|
| 1015 |
+
"# Modify the final layer for transfer learning (adjust `num_classes` as needed)\n",
|
| 1016 |
+
"num_classes = STFT_LENGTH # Set based on your current task\n",
|
| 1017 |
+
"model.fc = cplx.ComplexLinear(512 * BasicBlock.expansion, num_classes).to(device)\n",
|
| 1018 |
+
"\n",
|
| 1019 |
+
"# Unfreeze the final layer for training\n",
|
| 1020 |
+
"for param in model.fc.parameters():\n",
|
| 1021 |
+
" param.requires_grad = True\n"
|
| 1022 |
+
]
|
| 1023 |
+
},
|
| 1024 |
+
{
|
| 1025 |
+
"cell_type": "markdown",
|
| 1026 |
+
"id": "21e1e62b",
|
| 1027 |
+
"metadata": {},
|
| 1028 |
+
"source": [
|
| 1029 |
+
"### Complex Learning for Transfer Learning (Same as above but easier access)"
|
| 1030 |
+
]
|
| 1031 |
+
},
|
| 1032 |
+
{
|
| 1033 |
+
"cell_type": "code",
|
| 1034 |
+
"execution_count": null,
|
| 1035 |
+
"id": "4c6656d0",
|
| 1036 |
+
"metadata": {},
|
| 1037 |
+
"outputs": [],
|
| 1038 |
+
"source": [
|
| 1039 |
+
"class ComplexFocalLoss(nn.Module):\n",
|
| 1040 |
+
" def __init__(self, alpha=0.5, gamma=2, reduction='mean'):\n",
|
| 1041 |
+
" super(ComplexFocalLoss, self).__init__()\n",
|
| 1042 |
+
" self.alpha = alpha\n",
|
| 1043 |
+
" self.gamma = gamma\n",
|
| 1044 |
+
" self.reduction = reduction\n",
|
| 1045 |
+
"\n",
|
| 1046 |
+
" def forward(self, inputs, targets):\n",
|
| 1047 |
+
" real_inputs = inputs.real\n",
|
| 1048 |
+
" imag_inputs = inputs.imag\n",
|
| 1049 |
+
" \n",
|
| 1050 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
|
| 1051 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
|
| 1052 |
+
" \n",
|
| 1053 |
+
" real_pt = torch.exp(-real_BCE_loss)\n",
|
| 1054 |
+
" imag_pt = torch.exp(-imag_BCE_loss)\n",
|
| 1055 |
+
" \n",
|
| 1056 |
+
" real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
|
| 1057 |
+
" imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
|
| 1058 |
+
"\n",
|
| 1059 |
+
" if self.reduction == 'mean':\n",
|
| 1060 |
+
" return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
|
| 1061 |
+
" elif self.reduction == 'sum':\n",
|
| 1062 |
+
" return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
|
| 1063 |
+
" else:\n",
|
| 1064 |
+
" return real_F_loss + imag_F_loss\n",
|
| 1065 |
+
"\n",
|
| 1066 |
+
"# Update the IoU calculation to handle complex values\n",
|
| 1067 |
+
"def calculate_iou(pred, target, threshold=0.5):\n",
|
| 1068 |
+
" real_pred = (pred.real > threshold).float()\n",
|
| 1069 |
+
" imag_pred = (pred.imag > threshold).float()\n",
|
| 1070 |
+
" \n",
|
| 1071 |
+
" combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
|
| 1072 |
+
" \n",
|
| 1073 |
+
" intersection = (combined_pred * target).sum(dim=1)\n",
|
| 1074 |
+
" union = (combined_pred + target).sum(dim=1) - intersection\n",
|
| 1075 |
+
" iou = (intersection / union).mean().item()\n",
|
| 1076 |
+
" return iou"
|
| 1077 |
+
]
|
| 1078 |
+
},
|
| 1079 |
+
{
|
| 1080 |
+
"cell_type": "markdown",
|
| 1081 |
+
"id": "bc9b7701",
|
| 1082 |
+
"metadata": {},
|
| 1083 |
+
"source": [
|
| 1084 |
+
"### Transfer Learning"
|
| 1085 |
+
]
|
| 1086 |
+
},
|
| 1087 |
+
{
|
| 1088 |
+
"cell_type": "code",
|
| 1089 |
+
"execution_count": null,
|
| 1090 |
+
"id": "c291a42e",
|
| 1091 |
+
"metadata": {
|
| 1092 |
+
"scrolled": false
|
| 1093 |
+
},
|
| 1094 |
+
"outputs": [],
|
| 1095 |
+
"source": [
|
| 1096 |
+
"# Define a new criterion and optimizer for fine-tuning\n",
|
| 1097 |
+
"# You may select between Focal Loss or BCE as your criterion\n",
|
| 1098 |
+
"#criterion = ComplexValuedBCELoss() # or ComplexValuedBCELoss()\n",
|
| 1099 |
+
"criterion = ComplexFocalLoss()\n",
|
| 1100 |
+
"# Use a smaller learning rate for fine-tuning\n",
|
| 1101 |
+
"optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n",
|
| 1102 |
+
"\n",
|
| 1103 |
+
"# Train the model (fine-tuning)\n",
|
| 1104 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations= train_model(\n",
|
| 1105 |
+
" model, train_loader, valid_loader, criterion,\n",
|
| 1106 |
+
" initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
|
| 1107 |
+
")\n",
|
| 1108 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 1109 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 1110 |
+
]
|
| 1111 |
+
},
|
| 1112 |
+
{
|
| 1113 |
+
"cell_type": "markdown",
|
| 1114 |
+
"id": "98f81acc",
|
| 1115 |
+
"metadata": {},
|
| 1116 |
+
"source": [
|
| 1117 |
+
"## Transfer Transfer Learning (Different Radio)"
|
| 1118 |
+
]
|
| 1119 |
+
},
|
| 1120 |
+
{
|
| 1121 |
+
"cell_type": "code",
|
| 1122 |
+
"execution_count": null,
|
| 1123 |
+
"id": "55017794",
|
| 1124 |
+
"metadata": {},
|
| 1125 |
+
"outputs": [],
|
| 1126 |
+
"source": [
|
| 1127 |
+
"# Block to load pre-trained model and prepare for transfer learning\n",
|
| 1128 |
+
"device = torch.device(\"cuda\")\n",
|
| 1129 |
+
"\n",
|
| 1130 |
+
"model_path = \"/path/to/model/save.pth\"\n",
|
| 1131 |
+
"model = ComplexResNet18().to(device)\n",
|
| 1132 |
+
"#model = ComplexValuedBCELoss().to(device)\n",
|
| 1133 |
+
"model.load_state_dict(torch.load(model_path, map_location=device))\n",
|
| 1134 |
+
"\n",
|
| 1135 |
+
"# Freeze all layers except the final layer\n",
|
| 1136 |
+
"for param in model.parameters():\n",
|
| 1137 |
+
" param.requires_grad = False\n",
|
| 1138 |
+
"\n",
|
| 1139 |
+
"# Modify the final layer for transfer learning (adjust `num_classes` as needed)\n",
|
| 1140 |
+
"num_classes = STFT_LENGTH # Set based on your current task\n",
|
| 1141 |
+
"model.fc = cplx.ComplexLinear(512 * BasicBlock.expansion, num_classes).to(device)\n",
|
| 1142 |
+
"\n",
|
| 1143 |
+
"# Unfreeze the final layer for training\n",
|
| 1144 |
+
"for param in model.fc.parameters():\n",
|
| 1145 |
+
" param.requires_grad = True\n"
|
| 1146 |
+
]
|
| 1147 |
+
},
|
| 1148 |
+
{
|
| 1149 |
+
"cell_type": "code",
|
| 1150 |
+
"execution_count": null,
|
| 1151 |
+
"id": "5933b01f",
|
| 1152 |
+
"metadata": {},
|
| 1153 |
+
"outputs": [],
|
| 1154 |
+
"source": [
|
| 1155 |
+
"# Define a new criterion and optimizer for fine-tuning\n",
|
| 1156 |
+
"# You may select between Focal Loss or BCE as your criterion\n",
|
| 1157 |
+
"#criterion = ComplexValuedBCELoss() # or ComplexValuedBCELoss()\n",
|
| 1158 |
+
"criterion = ComplexFocalLoss()\n",
|
| 1159 |
+
"# Use a smaller learning rate for fine-tuning\n",
|
| 1160 |
+
"optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n",
|
| 1161 |
+
"\n",
|
| 1162 |
+
"# Train the model (fine-tuning)\n",
|
| 1163 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations= train_model(\n",
|
| 1164 |
+
" model, train_loader, valid_loader, criterion,\n",
|
| 1165 |
+
" initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=1\n",
|
| 1166 |
+
")\n",
|
| 1167 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 1168 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 1169 |
+
]
|
| 1170 |
+
},
|
| 1171 |
+
{
|
| 1172 |
+
"cell_type": "markdown",
|
| 1173 |
+
"id": "dbef8bad",
|
| 1174 |
+
"metadata": {},
|
| 1175 |
+
"source": [
|
| 1176 |
+
"### Evaluation CVNN OTA"
|
| 1177 |
+
]
|
| 1178 |
+
},
|
| 1179 |
+
{
|
| 1180 |
+
"cell_type": "code",
|
| 1181 |
+
"execution_count": null,
|
| 1182 |
+
"id": "d0ac03b7",
|
| 1183 |
+
"metadata": {},
|
| 1184 |
+
"outputs": [],
|
| 1185 |
+
"source": [
|
| 1186 |
+
"import torch\n",
|
| 1187 |
+
"from tqdm import tqdm\n",
|
| 1188 |
+
"import numpy as np\n",
|
| 1189 |
+
"from collections import defaultdict\n",
|
| 1190 |
+
"import torch.nn.functional as F\n",
|
| 1191 |
+
"from scipy.optimize import linear_sum_assignment\n",
|
| 1192 |
+
"from torch.utils.data import ConcatDataset"
|
| 1193 |
+
]
|
| 1194 |
+
},
|
| 1195 |
+
{
|
| 1196 |
+
"cell_type": "code",
|
| 1197 |
+
"execution_count": null,
|
| 1198 |
+
"id": "f831e874",
|
| 1199 |
+
"metadata": {},
|
| 1200 |
+
"outputs": [],
|
| 1201 |
+
"source": [
|
| 1202 |
+
"device = \"cuda\"\n",
|
| 1203 |
+
"\n",
|
| 1204 |
+
"model_path = \"/path/to/model/save.pth\"\n",
|
| 1205 |
+
"model = ComplexResNet18().to(device)\n",
|
| 1206 |
+
"model.load_state_dict(torch.load(model_path, map_location=device))\n",
|
| 1207 |
+
"model.eval()"
|
| 1208 |
+
]
|
| 1209 |
+
},
|
| 1210 |
+
{
|
| 1211 |
+
"cell_type": "code",
|
| 1212 |
+
"execution_count": null,
|
| 1213 |
+
"id": "a303080e",
|
| 1214 |
+
"metadata": {},
|
| 1215 |
+
"outputs": [],
|
| 1216 |
+
"source": [
|
| 1217 |
+
"# Load the pre-trained model for evaluation\n",
|
| 1218 |
+
"\n",
|
| 1219 |
+
"full_dataset = ConcatDataset([\n",
|
| 1220 |
+
" WidebandSignalDataset(signal_ids=train, return_snrs=True),\n",
|
| 1221 |
+
" WidebandSignalDataset(signal_ids=validation, return_snrs=True),\n",
|
| 1222 |
+
" WidebandSignalDataset(signal_ids=test, return_snrs=True)\n",
|
| 1223 |
+
"])\n",
|
| 1224 |
+
"full_loader = DataLoader(full_dataset, batch_size=64, shuffle=False)"
|
| 1225 |
+
]
|
| 1226 |
+
},
|
| 1227 |
+
{
|
| 1228 |
+
"cell_type": "markdown",
|
| 1229 |
+
"id": "ad326f1d",
|
| 1230 |
+
"metadata": {},
|
| 1231 |
+
"source": [
|
| 1232 |
+
"### Function initialization"
|
| 1233 |
+
]
|
| 1234 |
+
},
|
| 1235 |
+
{
|
| 1236 |
+
"cell_type": "code",
|
| 1237 |
+
"execution_count": null,
|
| 1238 |
+
"id": "00d0228c",
|
| 1239 |
+
"metadata": {},
|
| 1240 |
+
"outputs": [],
|
| 1241 |
+
"source": [
|
| 1242 |
+
"def expand_true(array, distance=1):\n",
|
| 1243 |
+
" # Create kernel of appropriate size\n",
|
| 1244 |
+
" kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n",
|
| 1245 |
+
" array = array.unsqueeze(1).float() # Add channel dimension\n",
|
| 1246 |
+
" result = F.conv1d(array, kernel, padding=distance)\n",
|
| 1247 |
+
" result = result.squeeze(1) # Remove the extra dimension\n",
|
| 1248 |
+
" return result > 0\n",
|
| 1249 |
+
"def reshape_to_2d(data):\n",
|
| 1250 |
+
" return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]\n",
|
| 1251 |
+
"def get_true_groups(tensor, device):\n",
|
| 1252 |
+
" assert tensor.dim() == 2, 'This function handles 2D tensor only'\n",
|
| 1253 |
+
" all_groups = []\n",
|
| 1254 |
+
" for i in range(tensor.size(0)):\n",
|
| 1255 |
+
" item = tensor[i]\n",
|
| 1256 |
+
" item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n",
|
| 1257 |
+
" diffs = item.float().diff()\n",
|
| 1258 |
+
" starts = (diffs == 1).nonzero(as_tuple=True)[0]\n",
|
| 1259 |
+
" ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n",
|
| 1260 |
+
" groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n",
|
| 1261 |
+
" all_groups.append(groups)\n",
|
| 1262 |
+
" return all_groups\n",
|
| 1263 |
+
"\n",
|
| 1264 |
+
"def calculate_iou(box1, box2):\n",
|
| 1265 |
+
" intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n",
|
| 1266 |
+
" union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n",
|
| 1267 |
+
" return intersection / union if union != 0 else 0\n",
|
| 1268 |
+
"\n",
|
| 1269 |
+
"def match_targets(targets, preds):\n",
|
| 1270 |
+
" ious = []\n",
|
| 1271 |
+
" for target in targets:\n",
|
| 1272 |
+
" iou_targets = []\n",
|
| 1273 |
+
" for pred in preds:\n",
|
| 1274 |
+
" iou_targets.append(calculate_iou(target, pred))\n",
|
| 1275 |
+
" ious.append(iou_targets)\n",
|
| 1276 |
+
" cost_matrix = np.array(ious)\n",
|
| 1277 |
+
" row_ind, col_ind = linear_sum_assignment(-cost_matrix)\n",
|
| 1278 |
+
" return row_ind, col_ind\n",
|
| 1279 |
+
"\n",
|
| 1280 |
+
"def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n",
|
| 1281 |
+
" ious = [0 for _ in target_boxes]\n",
|
| 1282 |
+
" matching_dict = dict(zip(*matching))\n",
|
| 1283 |
+
" for target_index, target_box in enumerate(target_boxes):\n",
|
| 1284 |
+
" if target_index in matching_dict:\n",
|
| 1285 |
+
" pred_index = matching_dict[target_index]\n",
|
| 1286 |
+
" if pred_index < len(prediction_boxes):\n",
|
| 1287 |
+
" box1 = target_box\n",
|
| 1288 |
+
" box2 = prediction_boxes[pred_index]\n",
|
| 1289 |
+
" ious[target_index] = calculate_iou(box1, box2)\n",
|
| 1290 |
+
" return ious\n",
|
| 1291 |
+
"def model_predictor(signals):\n",
|
| 1292 |
+
" # Convert signals to complex tensors\n",
|
| 1293 |
+
" if signals.dtype != torch.complex64 and signals.dtype != torch.complex128:\n",
|
| 1294 |
+
" signals = signals.type(torch.complex64)\n",
|
| 1295 |
+
" # Reshape the input signals to the expected shape\n",
|
| 1296 |
+
" signals = reshape_to_2d(signals)\n",
|
| 1297 |
+
" signals = signals.to(device)\n",
|
| 1298 |
+
" # Use the already loaded model and apply thresholding\n",
|
| 1299 |
+
" with torch.no_grad():\n",
|
| 1300 |
+
" outputs = model(signals)\n",
|
| 1301 |
+
" # Handle complex outputs appropriately\n",
|
| 1302 |
+
" real_outputs = outputs.real\n",
|
| 1303 |
+
" imag_outputs = outputs.imag\n",
|
| 1304 |
+
" real_pred = (real_outputs > 0.5)\n",
|
| 1305 |
+
" imag_pred = (imag_outputs > 0.5)\n",
|
| 1306 |
+
" combined_pred = torch.logical_or(real_pred, imag_pred)\n",
|
| 1307 |
+
" return expand_true(combined_pred.float())\n",
|
| 1308 |
+
"\n",
|
| 1309 |
+
"# Complex IoU Implementation\n",
|
| 1310 |
+
"def calculate_complex_iou(box1_real, box1_imag, box2_real, box2_imag):\n",
|
| 1311 |
+
" # Calculate real component intersection\n",
|
| 1312 |
+
" real_intersection = max(0, min(box1_real[1], box2_real[1]) - max(box1_real[0], box2_real[0]))\n",
|
| 1313 |
+
" real_union = max(box1_real[1], box2_real[1]) - min(box1_real[0], box2_real[0])\n",
|
| 1314 |
+
" \n",
|
| 1315 |
+
" # Calculate imaginary component intersection\n",
|
| 1316 |
+
" imag_intersection = max(0, min(box1_imag[1], box2_imag[1]) - max(box1_imag[0], box2_imag[0]))\n",
|
| 1317 |
+
" imag_union = max(box1_imag[1], box2_imag[1]) - min(box1_imag[0], box2_imag[0])\n",
|
| 1318 |
+
" \n",
|
| 1319 |
+
" # Combine intersections and unions\n",
|
| 1320 |
+
" total_intersection = real_intersection + imag_intersection\n",
|
| 1321 |
+
" total_union = real_union + imag_union\n",
|
| 1322 |
+
" \n",
|
| 1323 |
+
" # Return IoU\n",
|
| 1324 |
+
" return total_intersection / total_union if total_union != 0 else 0\n",
|
| 1325 |
+
"\n",
|
| 1326 |
+
"def match_complex_targets(targets_real, targets_imag, preds_real, preds_imag):\n",
|
| 1327 |
+
" ious = []\n",
|
| 1328 |
+
" for target_real, target_imag in zip(targets_real, targets_imag):\n",
|
| 1329 |
+
" iou_targets = []\n",
|
| 1330 |
+
" for pred_real, pred_imag in zip(preds_real, preds_imag):\n",
|
| 1331 |
+
" iou_targets.append(calculate_complex_iou(target_real, target_imag, pred_real, pred_imag))\n",
|
| 1332 |
+
" ious.append(iou_targets)\n",
|
| 1333 |
+
" cost_matrix = np.array(ious)\n",
|
| 1334 |
+
" row_ind, col_ind = linear_sum_assignment(-cost_matrix)\n",
|
| 1335 |
+
" return row_ind, col_ind\n",
|
| 1336 |
+
"\n",
|
| 1337 |
+
"def calculate_matched_complex_ious(target_boxes_real, target_boxes_imag, \n",
|
| 1338 |
+
" prediction_boxes_real, prediction_boxes_imag, matching):\n",
|
| 1339 |
+
" ious = [0 for _ in target_boxes_real]\n",
|
| 1340 |
+
" matching_dict = dict(zip(*matching))\n",
|
| 1341 |
+
" for target_index, (target_box_real, target_box_imag) in enumerate(zip(target_boxes_real, target_boxes_imag)):\n",
|
| 1342 |
+
" if target_index in matching_dict:\n",
|
| 1343 |
+
" pred_index = matching_dict[target_index]\n",
|
| 1344 |
+
" if pred_index < len(prediction_boxes_real):\n",
|
| 1345 |
+
" box1_real, box1_imag = target_box_real, target_box_imag\n",
|
| 1346 |
+
" box2_real, box2_imag = prediction_boxes_real[pred_index], prediction_boxes_imag[pred_index]\n",
|
| 1347 |
+
" ious[target_index] = calculate_complex_iou(box1_real, box1_imag, box2_real, box2_imag)\n",
|
| 1348 |
+
" return ious\n"
|
| 1349 |
+
]
|
| 1350 |
+
},
|
| 1351 |
+
{
|
| 1352 |
+
"cell_type": "markdown",
|
| 1353 |
+
"id": "c114c7a2",
|
| 1354 |
+
"metadata": {},
|
| 1355 |
+
"source": [
|
| 1356 |
+
"### Evaluate function"
|
| 1357 |
+
]
|
| 1358 |
+
},
|
| 1359 |
+
{
|
| 1360 |
+
"cell_type": "code",
|
| 1361 |
+
"execution_count": null,
|
| 1362 |
+
"id": "41f12e83",
|
| 1363 |
+
"metadata": {},
|
| 1364 |
+
"outputs": [],
|
| 1365 |
+
"source": [
|
| 1366 |
+
"def evaluate(predictor, data_loader, device=\"cuda\"):\n",
|
| 1367 |
+
" iou_thresholds = [0.5, 0.7, 0.9]\n",
|
| 1368 |
+
" snr_metrics = defaultdict(lambda: {\n",
|
| 1369 |
+
" \"iou_sum\": 0.0,\n",
|
| 1370 |
+
" \"iou_count\": 0,\n",
|
| 1371 |
+
" \"recall_counts\": defaultdict(int),\n",
|
| 1372 |
+
" \"total_samples\": defaultdict(int),\n",
|
| 1373 |
+
" \"correct_pixels\": 0,\n",
|
| 1374 |
+
" \"total_pixels\": 0\n",
|
| 1375 |
+
" })\n",
|
| 1376 |
+
" total_iou_sum, total_iou_count = 0.0, 0\n",
|
| 1377 |
+
" total_correct_pixels, total_total_pixels = 0, 0\n",
|
| 1378 |
+
" total_recall_counts = defaultdict(int)\n",
|
| 1379 |
+
" total_samples = defaultdict(int)\n",
|
| 1380 |
+
"\n",
|
| 1381 |
+
" for batch in tqdm(data_loader, desc=\"Evaluating\"):\n",
|
| 1382 |
+
" if len(batch) == 3:\n",
|
| 1383 |
+
" inputs, masks, snrs_in_batch = batch\n",
|
| 1384 |
+
" else:\n",
|
| 1385 |
+
" inputs, masks = batch\n",
|
| 1386 |
+
" snrs_in_batch = [0] * len(inputs) # Default SNR if not provided\n",
|
| 1387 |
+
"\n",
|
| 1388 |
+
" inputs = inputs.to(device)\n",
|
| 1389 |
+
" masks = masks.to(device)\n",
|
| 1390 |
+
" outputs = predictor(inputs)\n",
|
| 1391 |
+
"\n",
|
| 1392 |
+
" for i in range(len(inputs)):\n",
|
| 1393 |
+
" mask = masks[i]\n",
|
| 1394 |
+
" output = outputs[i]\n",
|
| 1395 |
+
"\n",
|
| 1396 |
+
" # Resize output to match mask shape if necessary\n",
|
| 1397 |
+
" if output.numel() != mask.numel():\n",
|
| 1398 |
+
" output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n",
|
| 1399 |
+
"\n",
|
| 1400 |
+
" thresholded_output = (output >= 0.5).float()\n",
|
| 1401 |
+
"\n",
|
| 1402 |
+
" correct_pixels = (thresholded_output == mask).sum().item()\n",
|
| 1403 |
+
" total_pixels = mask.numel()\n",
|
| 1404 |
+
" total_correct_pixels += correct_pixels\n",
|
| 1405 |
+
" total_total_pixels += total_pixels\n",
|
| 1406 |
+
"\n",
|
| 1407 |
+
" # Get SNR value and round it to the nearest integer\n",
|
| 1408 |
+
" snr = snrs_in_batch[i]\n",
|
| 1409 |
+
" if isinstance(snr, torch.Tensor):\n",
|
| 1410 |
+
" snr = snr.item()\n",
|
| 1411 |
+
" snr = int(round(snr)) # Round SNR to the nearest integer\n",
|
| 1412 |
+
"\n",
|
| 1413 |
+
" snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n",
|
| 1414 |
+
" snr_metrics[snr][\"total_pixels\"] += total_pixels\n",
|
| 1415 |
+
"\n",
|
| 1416 |
+
" target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n",
|
| 1417 |
+
" pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n",
|
| 1418 |
+
" if not target_boxes or not pred_boxes:\n",
|
| 1419 |
+
" continue\n",
|
| 1420 |
+
" matching = match_targets(target_boxes, pred_boxes)\n",
|
| 1421 |
+
" matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n",
|
| 1422 |
+
"\n",
|
| 1423 |
+
" snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n",
|
| 1424 |
+
" snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n",
|
| 1425 |
+
" total_iou_sum += sum(matched_ious)\n",
|
| 1426 |
+
" total_iou_count += len(matched_ious)\n",
|
| 1427 |
+
"\n",
|
| 1428 |
+
" for th in iou_thresholds:\n",
|
| 1429 |
+
" true_positives = sum(1 for iou in matched_ious if iou >= th)\n",
|
| 1430 |
+
" snr_metrics[snr][\"recall_counts\"][th] += true_positives\n",
|
| 1431 |
+
" snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n",
|
| 1432 |
+
" total_recall_counts[th] += true_positives\n",
|
| 1433 |
+
" total_samples[th] += len(target_boxes)\n",
|
| 1434 |
+
"\n",
|
| 1435 |
+
" # Calculate overall metrics\n",
|
| 1436 |
+
" overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n",
|
| 1437 |
+
" overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n",
|
| 1438 |
+
" overall_recall = {\n",
|
| 1439 |
+
" th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0\n",
|
| 1440 |
+
" for th in iou_thresholds\n",
|
| 1441 |
+
" }\n",
|
| 1442 |
+
"\n",
|
| 1443 |
+
" # Print overall results\n",
|
| 1444 |
+
" print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n",
|
| 1445 |
+
" print(f\"Overall IoU Score: {overall_iou:.4f}\")\n",
|
| 1446 |
+
" for th in iou_thresholds:\n",
|
| 1447 |
+
" print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n",
|
| 1448 |
+
"\n",
|
| 1449 |
+
" # Print per-SNR results\n",
|
| 1450 |
+
" for snr in sorted(snr_metrics.keys()):\n",
|
| 1451 |
+
" metrics = snr_metrics[snr]\n",
|
| 1452 |
+
" snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
|
| 1453 |
+
" snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
|
| 1454 |
+
" print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n",
|
| 1455 |
+
" print(f\" IoU: {snr_iou:.4f}\")\n",
|
| 1456 |
+
" for th in iou_thresholds:\n",
|
| 1457 |
+
" recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n",
|
| 1458 |
+
" print(f\" Recall at threshold {th}: {recall:.4f}\")\n",
|
| 1459 |
+
"\n",
|
| 1460 |
+
" return snr_metrics\n"
|
| 1461 |
+
]
|
| 1462 |
+
},
|
| 1463 |
+
{
|
| 1464 |
+
"cell_type": "code",
|
| 1465 |
+
"execution_count": null,
|
| 1466 |
+
"id": "0d2fd13f",
|
| 1467 |
+
"metadata": {
|
| 1468 |
+
"scrolled": false
|
| 1469 |
+
},
|
| 1470 |
+
"outputs": [],
|
| 1471 |
+
"source": [
|
| 1472 |
+
"# Run evaluation on the full dataset\n",
|
| 1473 |
+
"snr_metrics = evaluate(model_predictor, full_loader, device=device)"
|
| 1474 |
+
]
|
| 1475 |
+
},
|
| 1476 |
+
{
|
| 1477 |
+
"cell_type": "markdown",
|
| 1478 |
+
"id": "07eade04",
|
| 1479 |
+
"metadata": {},
|
| 1480 |
+
"source": [
|
| 1481 |
+
"### Save and Plot"
|
| 1482 |
+
]
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"cell_type": "code",
|
| 1486 |
+
"execution_count": null,
|
| 1487 |
+
"id": "bc84b73a",
|
| 1488 |
+
"metadata": {},
|
| 1489 |
+
"outputs": [],
|
| 1490 |
+
"source": [
|
| 1491 |
+
"import os\n",
|
| 1492 |
+
"import json\n",
|
| 1493 |
+
"import matplotlib.pyplot as plt\n",
|
| 1494 |
+
"\n",
|
| 1495 |
+
"def save_results_and_plot(snr_metrics, save_path):\n",
|
| 1496 |
+
" \"\"\"\n",
|
| 1497 |
+
" Saves evaluation results to a JSON file and generates plots for Accuracy, IoU, and Recall vs. SNR.\n",
|
| 1498 |
+
" Sets x-axis limits to range from -9 dB to 12 dB to eliminate blank space on the right.\n",
|
| 1499 |
+
"\n",
|
| 1500 |
+
" Args:\n",
|
| 1501 |
+
" snr_metrics (dict): The evaluation results obtained from the evaluate function.\n",
|
| 1502 |
+
" save_path (str): The directory path where results and plots will be saved.\n",
|
| 1503 |
+
"\n",
|
| 1504 |
+
" Outputs:\n",
|
| 1505 |
+
" - evaluation_results.json\n",
|
| 1506 |
+
" - accuracy_vs_snr.png and .svg\n",
|
| 1507 |
+
" - iou_vs_snr.png and .svg\n",
|
| 1508 |
+
" - recall_vs_snr.png and .svg\n",
|
| 1509 |
+
" \"\"\"\n",
|
| 1510 |
+
" # Ensure the directory exists\n",
|
| 1511 |
+
" os.makedirs(save_path, exist_ok=True)\n",
|
| 1512 |
+
" \n",
|
| 1513 |
+
" # Extract data from snr_metrics\n",
|
| 1514 |
+
" snr_list = sorted(snr_metrics.keys())\n",
|
| 1515 |
+
" accuracy_list = []\n",
|
| 1516 |
+
" iou_list = []\n",
|
| 1517 |
+
" recall_05 = []\n",
|
| 1518 |
+
" recall_07 = []\n",
|
| 1519 |
+
" recall_09 = []\n",
|
| 1520 |
+
" \n",
|
| 1521 |
+
" # Prepare data for JSON serialization\n",
|
| 1522 |
+
" json_data = {}\n",
|
| 1523 |
+
" \n",
|
| 1524 |
+
" for snr in snr_list:\n",
|
| 1525 |
+
" metrics = snr_metrics[snr]\n",
|
| 1526 |
+
" snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
|
| 1527 |
+
" snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
|
| 1528 |
+
" recall_at_05 = metrics[\"recall_counts\"][0.5] / metrics[\"total_samples\"][0.5] if metrics[\"total_samples\"][0.5] > 0 else 0\n",
|
| 1529 |
+
" recall_at_07 = metrics[\"recall_counts\"][0.7] / metrics[\"total_samples\"][0.7] if metrics[\"total_samples\"][0.7] > 0 else 0\n",
|
| 1530 |
+
" recall_at_09 = metrics[\"recall_counts\"][0.9] / metrics[\"total_samples\"][0.9] if metrics[\"total_samples\"][0.9] > 0 else 0\n",
|
| 1531 |
+
"\n",
|
| 1532 |
+
" # Append to lists for plotting\n",
|
| 1533 |
+
" accuracy_list.append(snr_accuracy)\n",
|
| 1534 |
+
" iou_list.append(snr_iou)\n",
|
| 1535 |
+
" recall_05.append(recall_at_05)\n",
|
| 1536 |
+
" recall_07.append(recall_at_07)\n",
|
| 1537 |
+
" recall_09.append(recall_at_09)\n",
|
| 1538 |
+
"\n",
|
| 1539 |
+
" # Prepare data for JSON\n",
|
| 1540 |
+
" json_data[snr] = {\n",
|
| 1541 |
+
" \"accuracy\": snr_accuracy,\n",
|
| 1542 |
+
" \"iou\": snr_iou,\n",
|
| 1543 |
+
" \"recall\": {\n",
|
| 1544 |
+
" \"0.5\": recall_at_05,\n",
|
| 1545 |
+
" \"0.7\": recall_at_07,\n",
|
| 1546 |
+
" \"0.9\": recall_at_09,\n",
|
| 1547 |
+
" }\n",
|
| 1548 |
+
" }\n",
|
| 1549 |
+
" \n",
|
| 1550 |
+
" # Save json_data to JSON file\n",
|
| 1551 |
+
" json_file_path = os.path.join(save_path, 'evaluation_results.json')\n",
|
| 1552 |
+
" with open(json_file_path, 'w') as json_file:\n",
|
| 1553 |
+
" json.dump(json_data, json_file, indent=4)\n",
|
| 1554 |
+
" \n",
|
| 1555 |
+
" # Plot Accuracy vs. SNR\n",
|
| 1556 |
+
" plt.figure(figsize=(10, 6))\n",
|
| 1557 |
+
" plt.plot(snr_list, accuracy_list, marker='o', label='Accuracy')\n",
|
| 1558 |
+
" plt.title('Accuracy vs. SNR')\n",
|
| 1559 |
+
" plt.xlabel('SNR (dB)')\n",
|
| 1560 |
+
" plt.ylabel('Accuracy (%)')\n",
|
| 1561 |
+
" plt.grid(True)\n",
|
| 1562 |
+
" plt.legend()\n",
|
| 1563 |
+
" \n",
|
| 1564 |
+
" # Set x-axis limits\n",
|
| 1565 |
+
" #plt.xlim(-9, 12)\n",
|
| 1566 |
+
" plt.xlim(-16, 16)\n",
|
| 1567 |
+
" # Save the plot\n",
|
| 1568 |
+
" accuracy_png_path = os.path.join(save_path, 'accuracy_vs_snr.png')\n",
|
| 1569 |
+
" accuracy_svg_path = os.path.join(save_path, 'accuracy_vs_snr.svg')\n",
|
| 1570 |
+
" plt.savefig(accuracy_png_path, format='png', bbox_inches='tight')\n",
|
| 1571 |
+
" plt.savefig(accuracy_svg_path, format='svg', bbox_inches='tight')\n",
|
| 1572 |
+
" \n",
|
| 1573 |
+
" plt.show()\n",
|
| 1574 |
+
" plt.close()\n",
|
| 1575 |
+
" \n",
|
| 1576 |
+
" # Plot IoU vs. SNR\n",
|
| 1577 |
+
" plt.figure(figsize=(10, 6))\n",
|
| 1578 |
+
" plt.plot(snr_list, iou_list, marker='o', color='orange', label='IoU')\n",
|
| 1579 |
+
" plt.title('IoU vs. SNR')\n",
|
| 1580 |
+
" plt.xlabel('SNR (dB)')\n",
|
| 1581 |
+
" plt.ylabel('IoU')\n",
|
| 1582 |
+
" plt.grid(True)\n",
|
| 1583 |
+
" plt.legend()\n",
|
| 1584 |
+
" \n",
|
| 1585 |
+
" # Set x-axis limits\n",
|
| 1586 |
+
" #plt.xlim(-9, 12)\n",
|
| 1587 |
+
" plt.xlim(-16, 16)\n",
|
| 1588 |
+
" # Save the plot\n",
|
| 1589 |
+
" iou_png_path = os.path.join(save_path, 'iou_vs_snr.png')\n",
|
| 1590 |
+
" iou_svg_path = os.path.join(save_path, 'iou_vs_snr.svg')\n",
|
| 1591 |
+
" plt.savefig(iou_png_path, format='png', bbox_inches='tight')\n",
|
| 1592 |
+
" plt.savefig(iou_svg_path, format='svg', bbox_inches='tight')\n",
|
| 1593 |
+
" \n",
|
| 1594 |
+
" plt.show()\n",
|
| 1595 |
+
" plt.close()\n",
|
| 1596 |
+
" \n",
|
| 1597 |
+
" # Plot Recall at Different IoU Thresholds vs. SNR\n",
|
| 1598 |
+
" plt.figure(figsize=(10, 6))\n",
|
| 1599 |
+
" plt.plot(snr_list, recall_05, marker='o', label='Recall @ IoU 0.5')\n",
|
| 1600 |
+
" plt.plot(snr_list, recall_07, marker='s', label='Recall @ IoU 0.7')\n",
|
| 1601 |
+
" plt.plot(snr_list, recall_09, marker='^', label='Recall @ IoU 0.9')\n",
|
| 1602 |
+
" plt.title('Recall at Different IoU Thresholds vs. SNR')\n",
|
| 1603 |
+
" plt.xlabel('SNR (dB)')\n",
|
| 1604 |
+
" plt.ylabel('Recall')\n",
|
| 1605 |
+
" plt.grid(True)\n",
|
| 1606 |
+
" plt.legend()\n",
|
| 1607 |
+
" \n",
|
| 1608 |
+
" # Set x-axis limits\n",
|
| 1609 |
+
" plt.xlim(-9, 12)\n",
|
| 1610 |
+
" \n",
|
| 1611 |
+
" # Save the plot\n",
|
| 1612 |
+
" recall_png_path = os.path.join(save_path, 'recall_vs_snr.png')\n",
|
| 1613 |
+
" recall_svg_path = os.path.join(save_path, 'recall_vs_snr.svg')\n",
|
| 1614 |
+
" plt.savefig(recall_png_path, format='png', bbox_inches='tight')\n",
|
| 1615 |
+
" plt.savefig(recall_svg_path, format='svg', bbox_inches='tight')\n",
|
| 1616 |
+
" \n",
|
| 1617 |
+
" plt.show()\n",
|
| 1618 |
+
" plt.close()\n"
|
| 1619 |
+
]
|
| 1620 |
+
},
|
| 1621 |
+
{
|
| 1622 |
+
"cell_type": "code",
|
| 1623 |
+
"execution_count": null,
|
| 1624 |
+
"id": "1974e70d",
|
| 1625 |
+
"metadata": {
|
| 1626 |
+
"scrolled": false
|
| 1627 |
+
},
|
| 1628 |
+
"outputs": [],
|
| 1629 |
+
"source": [
|
| 1630 |
+
"save_path = 'CMuSeNet_results/OTA'\n",
|
| 1631 |
+
"\n",
|
| 1632 |
+
"# Save results and generate plots\n",
|
| 1633 |
+
"save_results_and_plot(snr_metrics, save_path)"
|
| 1634 |
+
]
|
| 1635 |
+
}
|
| 1636 |
+
],
|
| 1637 |
+
"metadata": {
|
| 1638 |
+
"kernelspec": {
|
| 1639 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1640 |
+
"language": "python",
|
| 1641 |
+
"name": "python3"
|
| 1642 |
+
},
|
| 1643 |
+
"language_info": {
|
| 1644 |
+
"codemirror_mode": {
|
| 1645 |
+
"name": "ipython",
|
| 1646 |
+
"version": 3
|
| 1647 |
+
},
|
| 1648 |
+
"file_extension": ".py",
|
| 1649 |
+
"mimetype": "text/x-python",
|
| 1650 |
+
"name": "python",
|
| 1651 |
+
"nbconvert_exporter": "python",
|
| 1652 |
+
"pygments_lexer": "ipython3",
|
| 1653 |
+
"version": "3.10.9"
|
| 1654 |
+
}
|
| 1655 |
+
},
|
| 1656 |
+
"nbformat": 4,
|
| 1657 |
+
"nbformat_minor": 5
|
| 1658 |
+
}
|
CMuSeNet_Synthetic.ipynb
ADDED
|
@@ -0,0 +1,1241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "b5007b71",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"### Initialization"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "3e6b1226",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"### Initialization block\n",
|
| 19 |
+
"from pathlib import Path\n",
|
| 20 |
+
"import numpy as np\n",
|
| 21 |
+
"import json\n",
|
| 22 |
+
"import torch\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"from tqdm import tqdm\n",
|
| 25 |
+
"import math\n",
|
| 26 |
+
"from torch.utils.data import DataLoader, TensorDataset\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"STFT_LENGTH = 16 * 1024\n",
|
| 29 |
+
"DATA_DIR = Path(\"dataset/\")\n",
|
| 30 |
+
"SAMPLE_RATE = 20e6\n",
|
| 31 |
+
"MODULATIONS = [\"QPSK\", \"BPSK\", \"8-PSK\", \"8-QAM\", \"16-QAM\", \"GMSK\", \"2-FSK\"]\n",
|
| 32 |
+
"MODULATION_LABELS = {j: i for i, j in enumerate(MODULATIONS)}\n",
|
| 33 |
+
"NUMBER_OF_MODULATIONS = len(MODULATIONS)\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"def load_data(snr, name, load_metadata_only=False):\n",
|
| 36 |
+
" if not load_metadata_only:\n",
|
| 37 |
+
" with open(DATA_DIR/str(snr)/str(name)/\"data.dat\", \"rb\") as f:\n",
|
| 38 |
+
" signal = np.fromfile(f, dtype=np.complex128)\n",
|
| 39 |
+
" else:\n",
|
| 40 |
+
" signal = None\n",
|
| 41 |
+
" with open(DATA_DIR/str(snr)/str(name)/\"meta-data.json\") as f:\n",
|
| 42 |
+
" meta = json.load(f)\n",
|
| 43 |
+
" if type(meta) == dict:\n",
|
| 44 |
+
" meta = [meta]\n",
|
| 45 |
+
" return signal, meta\n",
|
| 46 |
+
"\n",
|
| 47 |
+
" \n",
|
| 48 |
+
"def _get_all_numbered_dirs(root_dir):\n",
|
| 49 |
+
" dirs = []\n",
|
| 50 |
+
" for directory in root_dir.iterdir():\n",
|
| 51 |
+
" dirs.append(int(directory.name))\n",
|
| 52 |
+
" dirs.sort()\n",
|
| 53 |
+
" return dirs\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"def get_signals(snr):\n",
|
| 56 |
+
" return _get_all_numbered_dirs(Path(DATA_DIR)/str(snr))\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"def get_snrs(root_dir=DATA_DIR):\n",
|
| 60 |
+
" return _get_all_numbered_dirs(root_dir)\n",
|
| 61 |
+
" \n",
|
| 62 |
+
" \n",
|
| 63 |
+
"def process_metadata(metadata):\n",
|
| 64 |
+
" scaled_metadata = [\n",
|
| 65 |
+
" {\n",
|
| 66 |
+
" \"position\": (SAMPLE_RATE/2 + i['fc'], i['bw']),\n",
|
| 67 |
+
" \"mod\": i[\"mod\"]\n",
|
| 68 |
+
" }\n",
|
| 69 |
+
" for i in metadata\n",
|
| 70 |
+
" ]\n",
|
| 71 |
+
" return scaled_metadata\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"def process_signal(signal):\n",
|
| 75 |
+
" signal = signal[:STFT_LENGTH]\n",
|
| 76 |
+
"\n",
|
| 77 |
+
" signal = np.fft.fft(signal)\n",
|
| 78 |
+
" signal = np.fft.fftshift(signal)\n",
|
| 79 |
+
" signal /= np.max(np.abs(signal))\n",
|
| 80 |
+
" \n",
|
| 81 |
+
" #return np.expand_dims(signal, axis=0)\n",
|
| 82 |
+
" return signal"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"id": "440b802c",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"### Data Loading"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"id": "31bc3770",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"MASK_SIZE = int(STFT_LENGTH)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"class WidebandSignalDataset(torch.utils.data.Dataset):\n",
|
| 103 |
+
" def __init__(self, signal_ids, mask_size=MASK_SIZE, return_snr=False):\n",
|
| 104 |
+
" self.mask_size = mask_size\n",
|
| 105 |
+
" self.signal_ids = signal_ids\n",
|
| 106 |
+
" self.return_snr = return_snr # New parameter to control SNR return\n",
|
| 107 |
+
" loaded_data = []\n",
|
| 108 |
+
" for snr, signal_id in tqdm(self.signal_ids):\n",
|
| 109 |
+
" signal, masks = self.process_signal(snr, signal_id)\n",
|
| 110 |
+
" loaded_data.append((signal, masks))\n",
|
| 111 |
+
" self.loaded_data = loaded_data\n",
|
| 112 |
+
"\n",
|
| 113 |
+
" def __len__(self):\n",
|
| 114 |
+
" return len(self.signal_ids)\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" def __getitem__(self, index):\n",
|
| 117 |
+
" signal, masks = self.loaded_data[index]\n",
|
| 118 |
+
" if self.return_snr:\n",
|
| 119 |
+
" snr, _ = self.signal_ids[index]\n",
|
| 120 |
+
" return signal, masks, snr # Return SNR during evaluation\n",
|
| 121 |
+
" else:\n",
|
| 122 |
+
" return signal, masks # Return only signal and masks during training\n",
|
| 123 |
+
"\n",
|
| 124 |
+
" def process_signal(self, snr, signal_id):\n",
|
| 125 |
+
" signal, metadata = load_data(snr, signal_id)\n",
|
| 126 |
+
" scaled_metadata = process_metadata(metadata)\n",
|
| 127 |
+
" signal = process_signal(signal)\n",
|
| 128 |
+
" signal = torch.from_numpy(signal)\n",
|
| 129 |
+
" masks = torch.zeros(self.mask_size)\n",
|
| 130 |
+
" scale_ratio = self.mask_size / SAMPLE_RATE\n",
|
| 131 |
+
" for meta in scaled_metadata:\n",
|
| 132 |
+
" f, b = meta['position']\n",
|
| 133 |
+
" x1, x2 = math.floor((f - b / 2) * scale_ratio), math.ceil((f + b / 2) * scale_ratio)\n",
|
| 134 |
+
" masks[x1:x2] = 1\n",
|
| 135 |
+
" return signal.type(torch.complex64), masks.type(torch.FloatTensor)\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"# Train test split 80 - 10 - 10\n",
|
| 138 |
+
"train, test, validation = [], [], [] \n",
|
| 139 |
+
"for snr in get_snrs():\n",
|
| 140 |
+
" signals = get_signals(snr)\n",
|
| 141 |
+
" total_signals = len(signals)\n",
|
| 142 |
+
" for signal in signals:\n",
|
| 143 |
+
" if signal <= 0.8 * total_signals:\n",
|
| 144 |
+
" train.append((snr, signal))\n",
|
| 145 |
+
" elif signal <= 0.9 * total_signals:\n",
|
| 146 |
+
" validation.append((snr, signal))\n",
|
| 147 |
+
" else:\n",
|
| 148 |
+
" test.append((snr, signal))\n",
|
| 149 |
+
" \n",
|
| 150 |
+
"print(\"Train\", len(train))\n",
|
| 151 |
+
"print(\"Validation\", len(validation))\n",
|
| 152 |
+
"print(\"Test\", len(test))\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"train_dataset = WidebandSignalDataset(signal_ids=train)\n",
|
| 155 |
+
"validation_dataset = WidebandSignalDataset(signal_ids=validation)\n",
|
| 156 |
+
"test_dataset = WidebandSignalDataset(signal_ids=test)"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "markdown",
|
| 161 |
+
"id": "637ae774",
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"source": [
|
| 164 |
+
"### Batch Loading"
|
| 165 |
+
]
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"cell_type": "code",
|
| 169 |
+
"execution_count": null,
|
| 170 |
+
"id": "a9af2450",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"outputs": [],
|
| 173 |
+
"source": [
|
| 174 |
+
"batch_size = 64 # Updated batch size\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 177 |
+
"valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"print(\"Train labels shape:\", len(train_dataset))\n",
|
| 180 |
+
"print(\"Validation labels shape:\", len(validation_dataset))"
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"cell_type": "markdown",
|
| 185 |
+
"id": "9a8e09e4",
|
| 186 |
+
"metadata": {},
|
| 187 |
+
"source": [
|
| 188 |
+
"### Early Stop"
|
| 189 |
+
]
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"cell_type": "code",
|
| 193 |
+
"execution_count": null,
|
| 194 |
+
"id": "24f79a24",
|
| 195 |
+
"metadata": {},
|
| 196 |
+
"outputs": [],
|
| 197 |
+
"source": [
|
| 198 |
+
"import os\n",
|
| 199 |
+
"\n",
|
| 200 |
+
"class EarlyStopping:\n",
|
| 201 |
+
" def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./models/CMuSeNet'):\n",
|
| 202 |
+
" self.patience = patience\n",
|
| 203 |
+
" self.verbose = verbose\n",
|
| 204 |
+
" self.delta = delta\n",
|
| 205 |
+
" self.counter = 0\n",
|
| 206 |
+
" self.best_score = None\n",
|
| 207 |
+
" self.early_stop = False\n",
|
| 208 |
+
" self.val_loss_min = float('inf')\n",
|
| 209 |
+
" self.best_model = None\n",
|
| 210 |
+
" self.save_path = save_path\n",
|
| 211 |
+
" os.makedirs(save_path, exist_ok=True)\n",
|
| 212 |
+
" \n",
|
| 213 |
+
" def __call__(self, val_loss, model):\n",
|
| 214 |
+
" score = -val_loss\n",
|
| 215 |
+
"\n",
|
| 216 |
+
" if self.best_score is None:\n",
|
| 217 |
+
" self.best_score = score\n",
|
| 218 |
+
" self.save_checkpoint(val_loss, model)\n",
|
| 219 |
+
" elif score < self.best_score + self.delta:\n",
|
| 220 |
+
" self.counter += 1\n",
|
| 221 |
+
" if self.verbose:\n",
|
| 222 |
+
" print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
|
| 223 |
+
" if self.counter >= self.patience:\n",
|
| 224 |
+
" self.early_stop = True\n",
|
| 225 |
+
" else:\n",
|
| 226 |
+
" self.best_score = score\n",
|
| 227 |
+
" self.save_checkpoint(val_loss, model)\n",
|
| 228 |
+
" self.counter = 0\n",
|
| 229 |
+
"\n",
|
| 230 |
+
" def save_checkpoint(self, val_loss, model):\n",
|
| 231 |
+
" if self.verbose:\n",
|
| 232 |
+
" print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n",
|
| 233 |
+
" self.val_loss_min = val_loss\n",
|
| 234 |
+
" self.best_model = model.state_dict()\n",
|
| 235 |
+
" save_path = os.path.join(self.save_path, 'best_model.pth')\n",
|
| 236 |
+
" torch.save(self.best_model, save_path)"
|
| 237 |
+
]
|
| 238 |
+
},
|
| 239 |
+
{
|
| 240 |
+
"cell_type": "markdown",
|
| 241 |
+
"id": "6c3fda74",
|
| 242 |
+
"metadata": {},
|
| 243 |
+
"source": [
|
| 244 |
+
"### Reshape"
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"cell_type": "code",
|
| 249 |
+
"execution_count": null,
|
| 250 |
+
"id": "5fcf91db",
|
| 251 |
+
"metadata": {},
|
| 252 |
+
"outputs": [],
|
| 253 |
+
"source": [
|
| 254 |
+
"import torch.nn as nn\n",
|
| 255 |
+
"import complexPyTorch.complexLayers as cplx\n",
|
| 256 |
+
"import torch.nn.functional as F\n",
|
| 257 |
+
"import torch\n",
|
| 258 |
+
"\n",
|
| 259 |
+
"def reshape_to_2d(data):\n",
|
| 260 |
+
" return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "markdown",
|
| 265 |
+
"id": "b7d7562c",
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"source": [
|
| 268 |
+
"### Complex IoU"
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": null,
|
| 274 |
+
"id": "7218c3f3",
|
| 275 |
+
"metadata": {},
|
| 276 |
+
"outputs": [],
|
| 277 |
+
"source": [
|
| 278 |
+
"def calculate_iou(pred, target, threshold=0.5):\n",
|
| 279 |
+
" real_pred = (pred.real > threshold).float()\n",
|
| 280 |
+
" imag_pred = (pred.imag > threshold).float()\n",
|
| 281 |
+
" \n",
|
| 282 |
+
" combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
|
| 283 |
+
" \n",
|
| 284 |
+
" intersection = (combined_pred * target).sum(dim=1)\n",
|
| 285 |
+
" union = (combined_pred + target).sum(dim=1) - intersection\n",
|
| 286 |
+
" iou = (intersection / union).mean().item()\n",
|
| 287 |
+
" return iou"
|
| 288 |
+
]
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"cell_type": "markdown",
|
| 292 |
+
"id": "64f4063c",
|
| 293 |
+
"metadata": {},
|
| 294 |
+
"source": [
|
| 295 |
+
"### Training"
|
| 296 |
+
]
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"cell_type": "code",
|
| 300 |
+
"execution_count": null,
|
| 301 |
+
"id": "66825110",
|
| 302 |
+
"metadata": {},
|
| 303 |
+
"outputs": [],
|
| 304 |
+
"source": [
|
| 305 |
+
"import time\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"def validate_model(model, valid_loader, criterion):\n",
|
| 308 |
+
" model.eval()\n",
|
| 309 |
+
" running_loss = 0.0\n",
|
| 310 |
+
" iou_scores = []\n",
|
| 311 |
+
" total_correct = 0\n",
|
| 312 |
+
" total_samples = 0\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" with torch.no_grad():\n",
|
| 315 |
+
" for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n",
|
| 316 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 317 |
+
" masks = masks.to(device)\n",
|
| 318 |
+
" outputs = model(inputs)\n",
|
| 319 |
+
" loss = criterion(outputs, masks)\n",
|
| 320 |
+
" running_loss += loss.item()\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" # Calculate IoU\n",
|
| 323 |
+
" iou = calculate_iou(outputs, masks, threshold=0.5)\n",
|
| 324 |
+
" iou_scores.append(iou)\n",
|
| 325 |
+
" \n",
|
| 326 |
+
" # Calculate accuracy\n",
|
| 327 |
+
" preds = ((outputs.real > 0.5) & (outputs.imag > 0.5)).float()\n",
|
| 328 |
+
" correct = (preds == masks).float().sum()\n",
|
| 329 |
+
" total_correct += correct.item()\n",
|
| 330 |
+
" total_samples += masks.numel()\n",
|
| 331 |
+
"\n",
|
| 332 |
+
" val_loss = running_loss / len(valid_loader)\n",
|
| 333 |
+
" mean_iou = sum(iou_scores) / len(iou_scores)\n",
|
| 334 |
+
" accuracy = total_correct / total_samples * 100\n",
|
| 335 |
+
"\n",
|
| 336 |
+
" print(f'Validation Loss: {val_loss:.6f}')\n",
|
| 337 |
+
" print(f'Validation Accuracy: {accuracy:.2f}%')\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" return val_loss, accuracy\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.00001], num_epochs=50, patience=3):\n",
|
| 342 |
+
" train_losses = []\n",
|
| 343 |
+
" val_losses = []\n",
|
| 344 |
+
" val_accuracies = []\n",
|
| 345 |
+
" epoch_durations = []\n",
|
| 346 |
+
" \n",
|
| 347 |
+
" current_lr = initial_lr\n",
|
| 348 |
+
" for lr in lr_steps:\n",
|
| 349 |
+
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
|
| 350 |
+
" early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n",
|
| 351 |
+
" print(\"Current learning rate: \", lr)\n",
|
| 352 |
+
" for epoch in range(num_epochs):\n",
|
| 353 |
+
" epoch_start_time = time.time()\n",
|
| 354 |
+
" \n",
|
| 355 |
+
" model.train()\n",
|
| 356 |
+
" running_loss = 0.0\n",
|
| 357 |
+
" for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n",
|
| 358 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 359 |
+
" masks = masks.to(device)\n",
|
| 360 |
+
" outputs = model(inputs)\n",
|
| 361 |
+
" loss = criterion(outputs, masks)\n",
|
| 362 |
+
"\n",
|
| 363 |
+
" optimizer.zero_grad()\n",
|
| 364 |
+
" loss.backward()\n",
|
| 365 |
+
" optimizer.step()\n",
|
| 366 |
+
"\n",
|
| 367 |
+
" running_loss += loss.item()\n",
|
| 368 |
+
"\n",
|
| 369 |
+
" epoch_loss = running_loss / len(train_loader)\n",
|
| 370 |
+
" train_losses.append(epoch_loss)\n",
|
| 371 |
+
" print(f\"Training Loss: {epoch_loss:.6f}\")\n",
|
| 372 |
+
"\n",
|
| 373 |
+
" val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n",
|
| 374 |
+
" val_losses.append(val_loss)\n",
|
| 375 |
+
" val_accuracies.append(val_accuracy)\n",
|
| 376 |
+
" early_stopping(val_loss, model)\n",
|
| 377 |
+
"\n",
|
| 378 |
+
" if early_stopping.early_stop:\n",
|
| 379 |
+
" print(\"Early stopping triggered\")\n",
|
| 380 |
+
" break\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" epoch_duration = time.time() - epoch_start_time\n",
|
| 383 |
+
" epoch_durations.append(epoch_duration)\n",
|
| 384 |
+
" if early_stopping.best_model is not None:\n",
|
| 385 |
+
" print(f\"Loading best model from lr {lr}\")\n",
|
| 386 |
+
" model.load_state_dict(early_stopping.best_model)\n",
|
| 387 |
+
" \n",
|
| 388 |
+
" print(\"Training completed.\")\n",
|
| 389 |
+
" print(\"Epoch durations:\", epoch_durations)\n",
|
| 390 |
+
" return model, train_losses, val_losses, val_accuracies, epoch_durations"
|
| 391 |
+
]
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
"cell_type": "markdown",
|
| 395 |
+
"id": "0b80cb51",
|
| 396 |
+
"metadata": {},
|
| 397 |
+
"source": [
|
| 398 |
+
"### ResNet-18"
|
| 399 |
+
]
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"cell_type": "code",
|
| 403 |
+
"execution_count": null,
|
| 404 |
+
"id": "2d208cb9",
|
| 405 |
+
"metadata": {},
|
| 406 |
+
"outputs": [],
|
| 407 |
+
"source": [
|
| 408 |
+
"import torch\n",
|
| 409 |
+
"import torch.nn as nn\n",
|
| 410 |
+
"import complexPyTorch.complexLayers as cplx\n",
|
| 411 |
+
"from typing import Optional, Callable, Type, Union, List\n",
|
| 412 |
+
"import torch.nn.functional as F\n",
|
| 413 |
+
"from torch import Tensor\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
|
| 416 |
+
" \"\"\"3x3 convolution with padding\"\"\"\n",
|
| 417 |
+
" return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
|
| 420 |
+
" \"\"\"1x1 convolution\"\"\"\n",
|
| 421 |
+
" return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
|
| 422 |
+
"\n",
|
| 423 |
+
"class BasicBlock(nn.Module):\n",
|
| 424 |
+
" expansion = 1\n",
|
| 425 |
+
"\n",
|
| 426 |
+
" def __init__(\n",
|
| 427 |
+
" self,\n",
|
| 428 |
+
" inplanes: int,\n",
|
| 429 |
+
" planes: int,\n",
|
| 430 |
+
" stride: int = 1,\n",
|
| 431 |
+
" downsample: Optional[nn.Module] = None,\n",
|
| 432 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 433 |
+
" ) -> None:\n",
|
| 434 |
+
" super(BasicBlock, self).__init__()\n",
|
| 435 |
+
" self.conv1 = conv3x3(inplanes, planes, stride)\n",
|
| 436 |
+
" self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 437 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 438 |
+
" self.conv2 = conv3x3(planes, planes)\n",
|
| 439 |
+
" self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 440 |
+
" self.downsample = downsample\n",
|
| 441 |
+
" self.stride = stride\n",
|
| 442 |
+
"\n",
|
| 443 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 444 |
+
" identity = x\n",
|
| 445 |
+
"\n",
|
| 446 |
+
" out = self.conv1(x)\n",
|
| 447 |
+
" out = self.bn1(out)\n",
|
| 448 |
+
" out = self.relu(out)\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" out = self.conv2(out)\n",
|
| 451 |
+
" out = self.bn2(out)\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" if self.downsample is not None:\n",
|
| 454 |
+
" identity = self.downsample(x)\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" out += identity\n",
|
| 457 |
+
" out = self.relu(out)\n",
|
| 458 |
+
"\n",
|
| 459 |
+
" return out\n",
|
| 460 |
+
"\n",
|
| 461 |
+
"class Bottleneck(nn.Module):\n",
|
| 462 |
+
" expansion = 4\n",
|
| 463 |
+
"\n",
|
| 464 |
+
" def __init__(\n",
|
| 465 |
+
" self,\n",
|
| 466 |
+
" inplanes: int,\n",
|
| 467 |
+
" planes: int,\n",
|
| 468 |
+
" stride: int = 1,\n",
|
| 469 |
+
" downsample: Optional[nn.Module] = None,\n",
|
| 470 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 471 |
+
" ) -> None:\n",
|
| 472 |
+
" super(Bottleneck, self).__init__()\n",
|
| 473 |
+
" self.conv1 = conv1x1(inplanes, planes)\n",
|
| 474 |
+
" self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 475 |
+
" self.conv2 = conv3x3(planes, planes, stride)\n",
|
| 476 |
+
" self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
|
| 477 |
+
" self.conv3 = conv1x1(planes, planes * self.expansion)\n",
|
| 478 |
+
" self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n",
|
| 479 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 480 |
+
" self.downsample = downsample\n",
|
| 481 |
+
" self.stride = stride\n",
|
| 482 |
+
"\n",
|
| 483 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 484 |
+
" identity = x\n",
|
| 485 |
+
"\n",
|
| 486 |
+
" out = self.conv1(x)\n",
|
| 487 |
+
" out = self.bn1(out)\n",
|
| 488 |
+
" out = self.relu(out)\n",
|
| 489 |
+
"\n",
|
| 490 |
+
" out = self.conv2(out)\n",
|
| 491 |
+
" out = self.bn2(out)\n",
|
| 492 |
+
" out = self.relu(out)\n",
|
| 493 |
+
"\n",
|
| 494 |
+
" out = self.conv3(out)\n",
|
| 495 |
+
" out = self.bn3(out)\n",
|
| 496 |
+
"\n",
|
| 497 |
+
" if self.downsample is not None:\n",
|
| 498 |
+
" identity = self.downsample(x)\n",
|
| 499 |
+
"\n",
|
| 500 |
+
" out += identity\n",
|
| 501 |
+
" out = self.relu(out)\n",
|
| 502 |
+
"\n",
|
| 503 |
+
" return out\n",
|
| 504 |
+
"\n",
|
| 505 |
+
"class ComplexResNet(nn.Module):\n",
|
| 506 |
+
" def __init__(\n",
|
| 507 |
+
" self,\n",
|
| 508 |
+
" block: Type[Union[BasicBlock, Bottleneck]],\n",
|
| 509 |
+
" layers: List[int],\n",
|
| 510 |
+
" num_classes: int = STFT_LENGTH,\n",
|
| 511 |
+
" zero_init_residual: bool = False,\n",
|
| 512 |
+
" groups: int = 1,\n",
|
| 513 |
+
" width_per_group: int = 64,\n",
|
| 514 |
+
" norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
|
| 515 |
+
" ) -> None:\n",
|
| 516 |
+
" super(ComplexResNet, self).__init__()\n",
|
| 517 |
+
" if norm_layer is None:\n",
|
| 518 |
+
" norm_layer = cplx.ComplexBatchNorm2d\n",
|
| 519 |
+
" self._norm_layer = norm_layer\n",
|
| 520 |
+
"\n",
|
| 521 |
+
" self.inplanes = 64\n",
|
| 522 |
+
" self.dilation = 1\n",
|
| 523 |
+
"\n",
|
| 524 |
+
" self.groups = groups\n",
|
| 525 |
+
" self.base_width = width_per_group\n",
|
| 526 |
+
" self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n",
|
| 527 |
+
" self.bn1 = norm_layer(self.inplanes)\n",
|
| 528 |
+
" self.relu = cplx.ComplexReLU()\n",
|
| 529 |
+
" self.layer1 = self._make_layer(block, 64, layers[0])\n",
|
| 530 |
+
" self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
|
| 531 |
+
" self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
|
| 532 |
+
" self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
|
| 533 |
+
" self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
|
| 534 |
+
" self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n",
|
| 535 |
+
" self.sigmoid = cplx.ComplexSigmoid()\n",
|
| 536 |
+
"\n",
|
| 537 |
+
" def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n",
|
| 538 |
+
" norm_layer = self._norm_layer\n",
|
| 539 |
+
" downsample = None\n",
|
| 540 |
+
" if stride != 1 or self.inplanes != planes * block.expansion:\n",
|
| 541 |
+
" downsample = nn.Sequential(\n",
|
| 542 |
+
" conv1x1(self.inplanes, planes * block.expansion, stride),\n",
|
| 543 |
+
" norm_layer(planes * block.expansion),\n",
|
| 544 |
+
" )\n",
|
| 545 |
+
"\n",
|
| 546 |
+
" layers = []\n",
|
| 547 |
+
" layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n",
|
| 548 |
+
" self.inplanes = planes * block.expansion\n",
|
| 549 |
+
" for _ in range(1, blocks):\n",
|
| 550 |
+
" layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n",
|
| 551 |
+
"\n",
|
| 552 |
+
" return nn.Sequential(*layers)\n",
|
| 553 |
+
"\n",
|
| 554 |
+
" def _forward_impl(self, x: Tensor) -> Tensor:\n",
|
| 555 |
+
" x = self.conv1(x)\n",
|
| 556 |
+
" x = self.bn1(x)\n",
|
| 557 |
+
" x = self.relu(x)\n",
|
| 558 |
+
"\n",
|
| 559 |
+
" x = self.layer1(x)\n",
|
| 560 |
+
" x = self.layer2(x)\n",
|
| 561 |
+
" x = self.layer3(x)\n",
|
| 562 |
+
" x = self.layer4(x)\n",
|
| 563 |
+
"\n",
|
| 564 |
+
" x = self.avgpool(x)\n",
|
| 565 |
+
" x = torch.flatten(x, 1)\n",
|
| 566 |
+
" x = self.fc(x)\n",
|
| 567 |
+
" x = self.sigmoid(x)\n",
|
| 568 |
+
" return x\n",
|
| 569 |
+
"\n",
|
| 570 |
+
" def forward(self, x: Tensor) -> Tensor:\n",
|
| 571 |
+
" return self._forward_impl(x)\n",
|
| 572 |
+
"\n",
|
| 573 |
+
"def ComplexResNet18():\n",
|
| 574 |
+
" return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"# Create the model instance\n",
|
| 577 |
+
"model = ComplexResNet18()\n",
|
| 578 |
+
"print(model)\n"
|
| 579 |
+
]
|
| 580 |
+
},
|
| 581 |
+
{
|
| 582 |
+
"cell_type": "markdown",
|
| 583 |
+
"id": "e4bc1b5d",
|
| 584 |
+
"metadata": {},
|
| 585 |
+
"source": [
|
| 586 |
+
"### Complex focal Loss"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
{
|
| 590 |
+
"cell_type": "code",
|
| 591 |
+
"execution_count": null,
|
| 592 |
+
"id": "61c29429",
|
| 593 |
+
"metadata": {},
|
| 594 |
+
"outputs": [],
|
| 595 |
+
"source": [
|
| 596 |
+
"class ComplexFocalLoss(nn.Module):\n",
|
| 597 |
+
" def __init__(self, alpha=1, gamma=2, reduction='mean'):\n",
|
| 598 |
+
" super(ComplexFocalLoss, self).__init__()\n",
|
| 599 |
+
" self.alpha = alpha\n",
|
| 600 |
+
" self.gamma = gamma\n",
|
| 601 |
+
" self.reduction = reduction\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" def forward(self, inputs, targets):\n",
|
| 604 |
+
" real_inputs = inputs.real\n",
|
| 605 |
+
" imag_inputs = inputs.imag\n",
|
| 606 |
+
" \n",
|
| 607 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
|
| 608 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
|
| 609 |
+
" \n",
|
| 610 |
+
" real_pt = torch.exp(-real_BCE_loss)\n",
|
| 611 |
+
" imag_pt = torch.exp(-imag_BCE_loss)\n",
|
| 612 |
+
" \n",
|
| 613 |
+
" real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
|
| 614 |
+
" imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
|
| 615 |
+
"\n",
|
| 616 |
+
" if self.reduction == 'mean':\n",
|
| 617 |
+
" return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
|
| 618 |
+
" elif self.reduction == 'sum':\n",
|
| 619 |
+
" return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
|
| 620 |
+
" else:\n",
|
| 621 |
+
" return real_F_loss + imag_F_loss\n",
|
| 622 |
+
"\n",
|
| 623 |
+
"# Update the IoU calculation to handle complex values\n",
|
| 624 |
+
"def calculate_iou(pred, target, threshold=0.5):\n",
|
| 625 |
+
" real_pred = (pred.real > threshold).float()\n",
|
| 626 |
+
" imag_pred = (pred.imag > threshold).float()\n",
|
| 627 |
+
" \n",
|
| 628 |
+
" combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
|
| 629 |
+
" \n",
|
| 630 |
+
" intersection = (combined_pred * target).sum(dim=1)\n",
|
| 631 |
+
" union = (combined_pred + target).sum(dim=1) - intersection\n",
|
| 632 |
+
" iou = (intersection / union).mean().item()\n",
|
| 633 |
+
" return iou"
|
| 634 |
+
]
|
| 635 |
+
},
|
| 636 |
+
{
|
| 637 |
+
"cell_type": "markdown",
|
| 638 |
+
"id": "abb35ba2",
|
| 639 |
+
"metadata": {},
|
| 640 |
+
"source": [
|
| 641 |
+
"### Training with complex focal loss"
|
| 642 |
+
]
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
"cell_type": "code",
|
| 646 |
+
"execution_count": null,
|
| 647 |
+
"id": "86d7526b",
|
| 648 |
+
"metadata": {},
|
| 649 |
+
"outputs": [],
|
| 650 |
+
"source": [
|
| 651 |
+
"# Initialize and train the CResNet-18 model\n",
|
| 652 |
+
"model = ComplexResNet18().to(device)\n",
|
| 653 |
+
"criterion = ComplexFocalLoss()\n",
|
| 654 |
+
"\n",
|
| 655 |
+
"# Train the model and validate it\n",
|
| 656 |
+
"#0.001, 0.0001, 0.00001, 0.000001\n",
|
| 657 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n",
|
| 658 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 659 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 660 |
+
]
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"cell_type": "markdown",
|
| 664 |
+
"id": "fd0c9d58",
|
| 665 |
+
"metadata": {},
|
| 666 |
+
"source": [
|
| 667 |
+
"### CVNN RV-BCE and CV-BCE Loss function implementation"
|
| 668 |
+
]
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"cell_type": "code",
|
| 672 |
+
"execution_count": null,
|
| 673 |
+
"id": "99c736b8",
|
| 674 |
+
"metadata": {},
|
| 675 |
+
"outputs": [],
|
| 676 |
+
"source": [
|
| 677 |
+
"# RV BCE Loss Function Definition\n",
|
| 678 |
+
"class RealValuedBCELoss(nn.Module):\n",
|
| 679 |
+
" def __init__(self, reduction='mean'):\n",
|
| 680 |
+
" super(RealValuedBCELoss, self).__init__()\n",
|
| 681 |
+
" self.reduction = reduction\n",
|
| 682 |
+
"\n",
|
| 683 |
+
" def forward(self, inputs, targets):\n",
|
| 684 |
+
" # Use only the real part of the complex inputs\n",
|
| 685 |
+
" real_inputs = inputs.real\n",
|
| 686 |
+
" BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
|
| 687 |
+
" return BCE_loss\n",
|
| 688 |
+
"\n",
|
| 689 |
+
" \n",
|
| 690 |
+
"# CV BCE Loss Function Definition\n",
|
| 691 |
+
"class ComplexValuedBCELoss(nn.Module):\n",
|
| 692 |
+
" def __init__(self, reduction='mean'):\n",
|
| 693 |
+
" super(ComplexValuedBCELoss, self).__init__()\n",
|
| 694 |
+
" self.reduction = reduction\n",
|
| 695 |
+
"\n",
|
| 696 |
+
" def forward(self, inputs, targets):\n",
|
| 697 |
+
" real_inputs = inputs.real\n",
|
| 698 |
+
" imag_inputs = inputs.imag\n",
|
| 699 |
+
"\n",
|
| 700 |
+
" # Calculate binary cross-entropy for both real and imaginary parts\n",
|
| 701 |
+
" real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
|
| 702 |
+
" imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n",
|
| 703 |
+
" \n",
|
| 704 |
+
" # Combine the losses (you can adjust the weighting if necessary)\n",
|
| 705 |
+
" combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n",
|
| 706 |
+
" return combined_BCE_loss"
|
| 707 |
+
]
|
| 708 |
+
},
|
| 709 |
+
{
|
| 710 |
+
"cell_type": "markdown",
|
| 711 |
+
"id": "d6930f39",
|
| 712 |
+
"metadata": {},
|
| 713 |
+
"source": [
|
| 714 |
+
"### RV-BCE Training"
|
| 715 |
+
]
|
| 716 |
+
},
|
| 717 |
+
{
|
| 718 |
+
"cell_type": "code",
|
| 719 |
+
"execution_count": null,
|
| 720 |
+
"id": "9e59d4c9",
|
| 721 |
+
"metadata": {},
|
| 722 |
+
"outputs": [],
|
| 723 |
+
"source": [
|
| 724 |
+
"# Set the criterion for RV BCE\n",
|
| 725 |
+
"criterion = RealValuedBCELoss()\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"# Train the ResNet-18 model with RV BCE\n",
|
| 728 |
+
"device = torch.device('cuda')\n",
|
| 729 |
+
"model = ComplexResNet18().to(device)\n",
|
| 730 |
+
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
|
| 731 |
+
"\n",
|
| 732 |
+
"# Start training with the previously defined train_model function\n",
|
| 733 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n",
|
| 734 |
+
" model, train_loader, valid_loader, criterion, \n",
|
| 735 |
+
" initial_lr=0.001, lr_steps=[0.001, 0.0001, 0.00001, 0.000001], num_epochs=50, patience=3\n",
|
| 736 |
+
")\n"
|
| 737 |
+
]
|
| 738 |
+
},
|
| 739 |
+
{
|
| 740 |
+
"cell_type": "markdown",
|
| 741 |
+
"id": "93d19ea7",
|
| 742 |
+
"metadata": {},
|
| 743 |
+
"source": [
|
| 744 |
+
"### CV-BCE Training"
|
| 745 |
+
]
|
| 746 |
+
},
|
| 747 |
+
{
|
| 748 |
+
"cell_type": "code",
|
| 749 |
+
"execution_count": null,
|
| 750 |
+
"id": "2c56d5b4",
|
| 751 |
+
"metadata": {},
|
| 752 |
+
"outputs": [],
|
| 753 |
+
"source": [
|
| 754 |
+
"# Set the criterion for CV BCE\n",
|
| 755 |
+
"criterion = ComplexValuedBCELoss()\n",
|
| 756 |
+
"\n",
|
| 757 |
+
"# Train the ResNet-18 model with CV BCE\n",
|
| 758 |
+
"device = torch.device('cuda')\n",
|
| 759 |
+
"model = ComplexResNet18().to(device)\n",
|
| 760 |
+
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
|
| 761 |
+
"\n",
|
| 762 |
+
"# Start training with the previously defined train_model function\n",
|
| 763 |
+
"model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n",
|
| 764 |
+
" model, train_loader, valid_loader, criterion, \n",
|
| 765 |
+
" initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
|
| 766 |
+
")\n",
|
| 767 |
+
"combined_epoch_time = sum(epoch_durations)\n",
|
| 768 |
+
"print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
|
| 769 |
+
]
|
| 770 |
+
},
|
| 771 |
+
{
|
| 772 |
+
"cell_type": "markdown",
|
| 773 |
+
"id": "f4f6530e",
|
| 774 |
+
"metadata": {},
|
| 775 |
+
"source": [
|
| 776 |
+
"### Plot training result (Accuracy, loss vs epoch)"
|
| 777 |
+
]
|
| 778 |
+
},
|
| 779 |
+
{
|
| 780 |
+
"cell_type": "code",
|
| 781 |
+
"execution_count": null,
|
| 782 |
+
"id": "43676a01",
|
| 783 |
+
"metadata": {},
|
| 784 |
+
"outputs": [],
|
| 785 |
+
"source": [
|
| 786 |
+
"import matplotlib.pyplot as plt\n",
|
| 787 |
+
"import json\n",
|
| 788 |
+
"import os\n",
|
| 789 |
+
"\n",
|
| 790 |
+
"# Ensure the directory exists\n",
|
| 791 |
+
"output_dir = 'cvnn_results/segmentation'\n",
|
| 792 |
+
"os.makedirs(output_dir, exist_ok=True)\n",
|
| 793 |
+
"\n",
|
| 794 |
+
"def save_metrics_to_json(train_losses, val_accuracies, epoch_durations, filename):\n",
|
| 795 |
+
" \"\"\"\n",
|
| 796 |
+
" Save the training losses and validation accuracies to a JSON file.\n",
|
| 797 |
+
" \n",
|
| 798 |
+
" Args:\n",
|
| 799 |
+
" train_losses (list): List of training losses.\n",
|
| 800 |
+
" val_accuracies (list): List of validation accuracies.\n",
|
| 801 |
+
" filename (str): The file name for the JSON file.\n",
|
| 802 |
+
" \"\"\"\n",
|
| 803 |
+
" metrics = {\n",
|
| 804 |
+
" \"train_losses\": train_losses,\n",
|
| 805 |
+
" \"val_accuracies\": val_accuracies,\n",
|
| 806 |
+
" \"epoch_durations\": epoch_durations\n",
|
| 807 |
+
" }\n",
|
| 808 |
+
" with open(os.path.join(output_dir, filename), 'w') as f:\n",
|
| 809 |
+
" json.dump(metrics, f)\n",
|
| 810 |
+
"\n",
|
| 811 |
+
"def plot_training_metrics(train_losses, val_accuracies, plot_filename):\n",
|
| 812 |
+
" \"\"\"\n",
|
| 813 |
+
" Plot the training loss and validation accuracy, and mark the epoch where accuracy reaches 99%.\n",
|
| 814 |
+
" \n",
|
| 815 |
+
" Args:\n",
|
| 816 |
+
" train_losses (list): List of training losses.\n",
|
| 817 |
+
" val_accuracies (list): List of validation accuracies.\n",
|
| 818 |
+
" plot_filename (str): The file name for saving the plot as SVG.\n",
|
| 819 |
+
" \"\"\"\n",
|
| 820 |
+
" epochs = range(1, len(train_losses) + 1)\n",
|
| 821 |
+
"\n",
|
| 822 |
+
" plt.figure(figsize=(14, 6))\n",
|
| 823 |
+
"\n",
|
| 824 |
+
" # Plot Training Loss\n",
|
| 825 |
+
" plt.subplot(1, 2, 1)\n",
|
| 826 |
+
" plt.plot(epochs, train_losses, label='Training Loss')\n",
|
| 827 |
+
" plt.xlabel('Epochs')\n",
|
| 828 |
+
" plt.ylabel('Loss')\n",
|
| 829 |
+
" plt.title('Training Loss')\n",
|
| 830 |
+
" plt.legend()\n",
|
| 831 |
+
"\n",
|
| 832 |
+
" # Plot Validation Accuracy\n",
|
| 833 |
+
" plt.subplot(1, 2, 2)\n",
|
| 834 |
+
" plt.plot(epochs, val_accuracies, label='Validation Accuracy')\n",
|
| 835 |
+
" plt.xlabel('Epochs')\n",
|
| 836 |
+
" plt.ylabel('Accuracy (%)')\n",
|
| 837 |
+
" plt.title('Validation Accuracy')\n",
|
| 838 |
+
" plt.legend()\n",
|
| 839 |
+
"\n",
|
| 840 |
+
" # Find the first epoch where validation accuracy reaches or exceeds 99%\n",
|
| 841 |
+
" for i, acc in enumerate(val_accuracies):\n",
|
| 842 |
+
" if acc >= 99:\n",
|
| 843 |
+
" first_99_epoch = i + 1 # Epochs are 1-based\n",
|
| 844 |
+
" plt.axvline(first_99_epoch, color='r', linestyle='--', label=f'99% reached at epoch {first_99_epoch}')\n",
|
| 845 |
+
" break\n",
|
| 846 |
+
"\n",
|
| 847 |
+
" plt.legend()\n",
|
| 848 |
+
" plt.tight_layout()\n",
|
| 849 |
+
"\n",
|
| 850 |
+
" # Save the plot as an SVG file\n",
|
| 851 |
+
" plt.savefig(os.path.join(output_dir, plot_filename), format='svg')\n",
|
| 852 |
+
" plt.show()\n",
|
| 853 |
+
"\n",
|
| 854 |
+
"# Save the metrics to JSON in cvnn_results/segmentation\n",
|
| 855 |
+
"save_metrics_to_json(train_losses, val_accuracies, epoch_durations, 'training_metrics.json')\n",
|
| 856 |
+
"\n",
|
| 857 |
+
"# Plot the metrics and highlight when accuracy reaches 99%, saving the plot as SVG\n",
|
| 858 |
+
"plot_training_metrics(train_losses, val_accuracies, 'training_metrics_plot.svg')"
|
| 859 |
+
]
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"cell_type": "markdown",
|
| 863 |
+
"id": "c6f4ea75",
|
| 864 |
+
"metadata": {},
|
| 865 |
+
"source": [
|
| 866 |
+
"### Evaluation "
|
| 867 |
+
]
|
| 868 |
+
},
|
| 869 |
+
{
|
| 870 |
+
"cell_type": "code",
|
| 871 |
+
"execution_count": null,
|
| 872 |
+
"id": "a303080e",
|
| 873 |
+
"metadata": {},
|
| 874 |
+
"outputs": [],
|
| 875 |
+
"source": [
|
| 876 |
+
"# Load the pre-trained model for evaluation\n",
|
| 877 |
+
"import torch\n",
|
| 878 |
+
"\n",
|
| 879 |
+
"device = \"cuda\"\n",
|
| 880 |
+
"\n",
|
| 881 |
+
"model_path = \"path/to/the/model\" #Please change this to the model path you trained\n",
|
| 882 |
+
"model = ComplexResNet18().to(device)\n",
|
| 883 |
+
"model.load_state_dict(torch.load(model_path, map_location=device))\n",
|
| 884 |
+
"model.eval()\n"
|
| 885 |
+
]
|
| 886 |
+
},
|
| 887 |
+
{
|
| 888 |
+
"cell_type": "code",
|
| 889 |
+
"execution_count": null,
|
| 890 |
+
"id": "0590b6ef",
|
| 891 |
+
"metadata": {},
|
| 892 |
+
"outputs": [],
|
| 893 |
+
"source": [
|
| 894 |
+
"import torch\n",
|
| 895 |
+
"from tqdm import tqdm\n",
|
| 896 |
+
"from torch.utils.data import DataLoader\n",
|
| 897 |
+
"import numpy as np\n",
|
| 898 |
+
"\n",
|
| 899 |
+
"# Define thresholds for recall calculation\n",
|
| 900 |
+
"iou_thresholds = [0.5, 0.7, 0.9]\n",
|
| 901 |
+
"\n",
|
| 902 |
+
"# Initialize metrics\n",
|
| 903 |
+
"snr_results = {}\n",
|
| 904 |
+
"total_accuracy = 0.0\n",
|
| 905 |
+
"total_samples = 0\n",
|
| 906 |
+
"iou_scores = {th: 0.0 for th in iou_thresholds}\n",
|
| 907 |
+
"recall_counts = {th: 0 for th in iou_thresholds}\n",
|
| 908 |
+
"BATCH_SIZE = 64\n",
|
| 909 |
+
"# Create DataLoader for the entire dataset\n",
|
| 910 |
+
"full_dataset = WidebandSignalDataset(signal_ids=train + validation + test, return_snr=True)\n",
|
| 911 |
+
"full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)"
|
| 912 |
+
]
|
| 913 |
+
},
|
| 914 |
+
{
|
| 915 |
+
"cell_type": "markdown",
|
| 916 |
+
"id": "6db6a18f",
|
| 917 |
+
"metadata": {},
|
| 918 |
+
"source": [
|
| 919 |
+
"### Bounding Box"
|
| 920 |
+
]
|
| 921 |
+
},
|
| 922 |
+
{
|
| 923 |
+
"cell_type": "code",
|
| 924 |
+
"execution_count": null,
|
| 925 |
+
"id": "e396c72c",
|
| 926 |
+
"metadata": {},
|
| 927 |
+
"outputs": [],
|
| 928 |
+
"source": [
|
| 929 |
+
"import torch\n",
|
| 930 |
+
"from collections import defaultdict\n",
|
| 931 |
+
"import time\n",
|
| 932 |
+
"from tqdm import tqdm\n",
|
| 933 |
+
"import torch\n",
|
| 934 |
+
"import torch.nn.functional as F\n",
|
| 935 |
+
"from scipy.optimize import linear_sum_assignment\n",
|
| 936 |
+
"\n",
|
| 937 |
+
"def expand_true(array, distance=1):\n",
|
| 938 |
+
" # Create kernel of appropriate size\n",
|
| 939 |
+
" kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n",
|
| 940 |
+
" array = array.unsqueeze(1).float() # Add channel dimension\n",
|
| 941 |
+
" result = F.conv1d(array, kernel, padding=distance)\n",
|
| 942 |
+
" result = result.squeeze(1) # Remove the extra dimension\n",
|
| 943 |
+
" \n",
|
| 944 |
+
" # Convert values greater than 0 to `True`\n",
|
| 945 |
+
" return result > 0\n",
|
| 946 |
+
"\n",
|
| 947 |
+
"# Define supporting functions based on your friend's code\n",
|
| 948 |
+
"def get_true_groups(tensor, device):\n",
|
| 949 |
+
" assert tensor.dim() == 2, 'This function handles 2D tensor only'\n",
|
| 950 |
+
" all_groups = []\n",
|
| 951 |
+
" for i in range(tensor.size(0)):\n",
|
| 952 |
+
" item = tensor[i]\n",
|
| 953 |
+
" item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n",
|
| 954 |
+
" diffs = item.float().diff()\n",
|
| 955 |
+
" starts = (diffs == 1).nonzero(as_tuple=True)[0]\n",
|
| 956 |
+
" ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n",
|
| 957 |
+
" groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n",
|
| 958 |
+
" all_groups.append(groups)\n",
|
| 959 |
+
" return all_groups\n",
|
| 960 |
+
"\n",
|
| 961 |
+
"def get_target_boxes(metadata, number_of_bins, sample_rate=SAMPLE_RATE):\n",
|
| 962 |
+
" scale_ratio = number_of_bins / sample_rate\n",
|
| 963 |
+
" targets = []\n",
|
| 964 |
+
" masks = torch.zeros(number_of_bins)\n",
|
| 965 |
+
" for meta in metadata:\n",
|
| 966 |
+
" f, b = meta['position']\n",
|
| 967 |
+
" x1, x2 = math.floor((f-b/2)*scale_ratio), math.ceil((f+b/2)*scale_ratio)\n",
|
| 968 |
+
" masks[x1:x2] = 1\n",
|
| 969 |
+
" targets.append((x1, x2))\n",
|
| 970 |
+
" return targets, masks\n",
|
| 971 |
+
"\n",
|
| 972 |
+
"def get_target_boxes_batch(batch_metadata, number_of_bins, sample_rate=SAMPLE_RATE):\n",
|
| 973 |
+
" all_targets, all_masks = [], []\n",
|
| 974 |
+
" for metadata in batch_metadata:\n",
|
| 975 |
+
" targets, masks = get_target_boxes(metadata, number_of_bins, sample_rate)\n",
|
| 976 |
+
" all_targets.append(targets)\n",
|
| 977 |
+
" all_masks.append(masks)\n",
|
| 978 |
+
" return all_targets, all_masks\n",
|
| 979 |
+
"\n",
|
| 980 |
+
"def calculate_iou(box1, box2):\n",
|
| 981 |
+
" intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n",
|
| 982 |
+
" union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n",
|
| 983 |
+
" return intersection / union if union != 0 else 0\n",
|
| 984 |
+
"\n",
|
| 985 |
+
"def match_targets(targets, preds):\n",
|
| 986 |
+
" ious = []\n",
|
| 987 |
+
" for target in targets:\n",
|
| 988 |
+
" iou_targets = []\n",
|
| 989 |
+
" for pred in preds:\n",
|
| 990 |
+
" iou_targets.append(calculate_iou(target, pred))\n",
|
| 991 |
+
" ious.append(iou_targets)\n",
|
| 992 |
+
" return linear_sum_assignment(ious, maximize=True)\n",
|
| 993 |
+
"\n",
|
| 994 |
+
"def match_targets_batch(batch_targets, batch_preds):\n",
|
| 995 |
+
" all_assignments = []\n",
|
| 996 |
+
" for targets, preds in zip(batch_targets, batch_preds):\n",
|
| 997 |
+
" all_assignments.append(match_targets(targets, preds))\n",
|
| 998 |
+
" return all_assignments\n",
|
| 999 |
+
"\n",
|
| 1000 |
+
"def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n",
|
| 1001 |
+
" ious = [0 for _ in target_boxes]\n",
|
| 1002 |
+
" matching_dict = dict(zip(*matching))\n",
|
| 1003 |
+
" for target_index, target_box in enumerate(target_boxes):\n",
|
| 1004 |
+
" if target_index in matching_dict:\n",
|
| 1005 |
+
" box1 = target_box\n",
|
| 1006 |
+
" box2 = prediction_boxes[matching_dict[target_index]]\n",
|
| 1007 |
+
" ious[target_index] = calculate_iou(box1, box2)\n",
|
| 1008 |
+
" return ious\n",
|
| 1009 |
+
"\n",
|
| 1010 |
+
"def calculate_matched_iou_mean_batch(batch_target_boxes, batch_pred_boxes, batch_matching):\n",
|
| 1011 |
+
" all_ious = []\n",
|
| 1012 |
+
" for args in zip(batch_target_boxes, batch_pred_boxes, batch_matching):\n",
|
| 1013 |
+
" all_ious.append(calculate_matched_ious(*args))\n",
|
| 1014 |
+
" return all_ious\n",
|
| 1015 |
+
"\n"
|
| 1016 |
+
]
|
| 1017 |
+
},
|
| 1018 |
+
{
|
| 1019 |
+
"cell_type": "code",
|
| 1020 |
+
"execution_count": null,
|
| 1021 |
+
"id": "24d483c1",
|
| 1022 |
+
"metadata": {},
|
| 1023 |
+
"outputs": [],
|
| 1024 |
+
"source": [
|
| 1025 |
+
"from collections import defaultdict\n",
|
| 1026 |
+
"from tqdm import tqdm\n",
|
| 1027 |
+
"def model_predictor(signals):\n",
|
| 1028 |
+
" # Use the already loaded model and apply thresholding\n",
|
| 1029 |
+
" signals = reshape_to_2d(signals)\n",
|
| 1030 |
+
" outputs = model(signals)\n",
|
| 1031 |
+
" return expand_true(outputs.real > 0.5) # Use real part for thresholding\n",
|
| 1032 |
+
"def evaluate(predictor, data_loader, device=\"cuda\"):\n",
|
| 1033 |
+
" snr_metrics = defaultdict(lambda: {\n",
|
| 1034 |
+
" \"iou_sum\": 0.0,\n",
|
| 1035 |
+
" \"iou_count\": 0,\n",
|
| 1036 |
+
" \"recall_counts\": defaultdict(int),\n",
|
| 1037 |
+
" \"total_samples\": defaultdict(int),\n",
|
| 1038 |
+
" \"correct_pixels\": 0,\n",
|
| 1039 |
+
" \"total_pixels\": 0\n",
|
| 1040 |
+
" })\n",
|
| 1041 |
+
" total_iou_sum, total_iou_count = 0.0, 0\n",
|
| 1042 |
+
" total_correct_pixels, total_total_pixels = 0, 0\n",
|
| 1043 |
+
" total_recall_counts = defaultdict(int)\n",
|
| 1044 |
+
" total_samples = defaultdict(int)\n",
|
| 1045 |
+
"\n",
|
| 1046 |
+
" for inputs, masks, snrs_in_batch in tqdm(data_loader, desc=\"Evaluating\"):\n",
|
| 1047 |
+
" #inputs = inputs.to(device)\n",
|
| 1048 |
+
" inputs = reshape_to_2d(inputs).to(device)\n",
|
| 1049 |
+
" masks = masks.to(device)\n",
|
| 1050 |
+
" outputs = predictor(inputs)\n",
|
| 1051 |
+
"\n",
|
| 1052 |
+
" for i in range(len(snrs_in_batch)):\n",
|
| 1053 |
+
" snr = snrs_in_batch[i].item()\n",
|
| 1054 |
+
" mask = masks[i]\n",
|
| 1055 |
+
" output = outputs[i]\n",
|
| 1056 |
+
"\n",
|
| 1057 |
+
" # Ensure output matches mask shape\n",
|
| 1058 |
+
" if output.numel() != mask.numel():\n",
|
| 1059 |
+
" output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n",
|
| 1060 |
+
"\n",
|
| 1061 |
+
" thresholded_output = (output.real >= 0.5).float()\n",
|
| 1062 |
+
"\n",
|
| 1063 |
+
" correct_pixels = (thresholded_output == mask).sum().item()\n",
|
| 1064 |
+
" total_pixels = mask.numel()\n",
|
| 1065 |
+
" snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n",
|
| 1066 |
+
" snr_metrics[snr][\"total_pixels\"] += total_pixels\n",
|
| 1067 |
+
" total_correct_pixels += correct_pixels\n",
|
| 1068 |
+
" total_total_pixels += total_pixels\n",
|
| 1069 |
+
"\n",
|
| 1070 |
+
" target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n",
|
| 1071 |
+
" pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n",
|
| 1072 |
+
" if not target_boxes or not pred_boxes:\n",
|
| 1073 |
+
" continue\n",
|
| 1074 |
+
" matching = match_targets(target_boxes, pred_boxes)\n",
|
| 1075 |
+
" matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n",
|
| 1076 |
+
"\n",
|
| 1077 |
+
" snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n",
|
| 1078 |
+
" snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n",
|
| 1079 |
+
" total_iou_sum += sum(matched_ious)\n",
|
| 1080 |
+
" total_iou_count += len(matched_ious)\n",
|
| 1081 |
+
"\n",
|
| 1082 |
+
" for th in iou_thresholds:\n",
|
| 1083 |
+
" true_positives = sum(1 for iou in matched_ious if iou >= th)\n",
|
| 1084 |
+
" snr_metrics[snr][\"recall_counts\"][th] += true_positives\n",
|
| 1085 |
+
" snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n",
|
| 1086 |
+
" total_recall_counts[th] += true_positives\n",
|
| 1087 |
+
" total_samples[th] += len(target_boxes)\n",
|
| 1088 |
+
"\n",
|
| 1089 |
+
" # Calculate overall metrics\n",
|
| 1090 |
+
" overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n",
|
| 1091 |
+
" overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n",
|
| 1092 |
+
" overall_recall = {th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0 for th in iou_thresholds}\n",
|
| 1093 |
+
"\n",
|
| 1094 |
+
" # Print overall results\n",
|
| 1095 |
+
" print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n",
|
| 1096 |
+
" print(f\"Overall IoU Score: {overall_iou:.4f}\")\n",
|
| 1097 |
+
" for th in iou_thresholds:\n",
|
| 1098 |
+
" print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n",
|
| 1099 |
+
"\n",
|
| 1100 |
+
" # Print per-SNR results\n",
|
| 1101 |
+
" for snr, metrics in sorted(snr_metrics.items()):\n",
|
| 1102 |
+
" snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
|
| 1103 |
+
" snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
|
| 1104 |
+
" print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n",
|
| 1105 |
+
" print(f\" IoU: {snr_iou:.4f}\")\n",
|
| 1106 |
+
" for th in iou_thresholds:\n",
|
| 1107 |
+
" recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n",
|
| 1108 |
+
" print(f\" Recall at threshold {th}: {recall:.4f}\")\n",
|
| 1109 |
+
"\n",
|
| 1110 |
+
" return snr_metrics\n"
|
| 1111 |
+
]
|
| 1112 |
+
},
|
| 1113 |
+
{
|
| 1114 |
+
"cell_type": "code",
|
| 1115 |
+
"execution_count": null,
|
| 1116 |
+
"id": "a71c18ba",
|
| 1117 |
+
"metadata": {
|
| 1118 |
+
"scrolled": false
|
| 1119 |
+
},
|
| 1120 |
+
"outputs": [],
|
| 1121 |
+
"source": [
|
| 1122 |
+
"snr_metrics = evaluate(model_predictor, full_loader, device=device)"
|
| 1123 |
+
]
|
| 1124 |
+
},
|
| 1125 |
+
{
|
| 1126 |
+
"cell_type": "markdown",
|
| 1127 |
+
"id": "87417c7b",
|
| 1128 |
+
"metadata": {},
|
| 1129 |
+
"source": [
|
| 1130 |
+
"### Plot and Save"
|
| 1131 |
+
]
|
| 1132 |
+
},
|
| 1133 |
+
{
|
| 1134 |
+
"cell_type": "code",
|
| 1135 |
+
"execution_count": null,
|
| 1136 |
+
"id": "1dbfb5e6",
|
| 1137 |
+
"metadata": {
|
| 1138 |
+
"scrolled": false
|
| 1139 |
+
},
|
| 1140 |
+
"outputs": [],
|
| 1141 |
+
"source": [
|
| 1142 |
+
"import json\n",
|
| 1143 |
+
"import matplotlib.pyplot as plt\n",
|
| 1144 |
+
"from pathlib import Path\n",
|
| 1145 |
+
"\n",
|
| 1146 |
+
"# Define the path for saving the JSON file and plots\n",
|
| 1147 |
+
"save_path = Path(\"CMuSeNet_plots/Synthetic\")\n",
|
| 1148 |
+
"save_path.mkdir(parents=True, exist_ok=True)\n",
|
| 1149 |
+
"json_file_path = save_path / \"evaluation_results.json\"\n",
|
| 1150 |
+
"\n",
|
| 1151 |
+
"# Save metrics and plot results\n",
|
| 1152 |
+
"def save_and_plot_results(snr_metrics, iou_thresholds):\n",
|
| 1153 |
+
" # Prepare data for plotting and JSON saving\n",
|
| 1154 |
+
" snr_values = sorted(snr_metrics.keys())\n",
|
| 1155 |
+
" iou_scores = [snr_metrics[snr][\"iou_sum\"] / snr_metrics[snr][\"iou_count\"] if snr_metrics[snr][\"iou_count\"] > 0 else 0 for snr in snr_values]\n",
|
| 1156 |
+
" accuracies = [(snr_metrics[snr][\"correct_pixels\"] / snr_metrics[snr][\"total_pixels\"]) * 100 if snr_metrics[snr][\"total_pixels\"] > 0 else 0 for snr in snr_values]\n",
|
| 1157 |
+
" recalls = {th: [(snr_metrics[snr][\"recall_counts\"][th] / snr_metrics[snr][\"total_samples\"][th]) if snr_metrics[snr][\"total_samples\"][th] > 0 else 0 for snr in snr_values] for th in iou_thresholds}\n",
|
| 1158 |
+
"\n",
|
| 1159 |
+
" # Save results to JSON\n",
|
| 1160 |
+
" results = {\n",
|
| 1161 |
+
" \"SNR\": snr_values,\n",
|
| 1162 |
+
" \"IoU_Scores\": iou_scores,\n",
|
| 1163 |
+
" \"Accuracy\": accuracies,\n",
|
| 1164 |
+
" \"Recall\": {str(th): recalls[th] for th in iou_thresholds}\n",
|
| 1165 |
+
" }\n",
|
| 1166 |
+
" with open(json_file_path, \"w\") as f:\n",
|
| 1167 |
+
" json.dump(results, f, indent=4)\n",
|
| 1168 |
+
" print(f\"Results saved to {json_file_path}\")\n",
|
| 1169 |
+
"\n",
|
| 1170 |
+
" # Plot IoU vs SNR\n",
|
| 1171 |
+
" plt.figure()\n",
|
| 1172 |
+
" plt.plot(snr_values, iou_scores, marker='o', label=\"IoU Score\")\n",
|
| 1173 |
+
" plt.xlabel(\"SNR (dB)\")\n",
|
| 1174 |
+
" plt.ylabel(\"IoU Score\")\n",
|
| 1175 |
+
" plt.title(\"IoU Score vs. SNR\")\n",
|
| 1176 |
+
" plt.grid(True)\n",
|
| 1177 |
+
" plt.legend()\n",
|
| 1178 |
+
" plt.savefig(save_path / \"IoU_vs_SNR.png\")\n",
|
| 1179 |
+
" plt.savefig(save_path / \"IoU_vs_SNR.svg\")\n",
|
| 1180 |
+
" plt.show()\n",
|
| 1181 |
+
"\n",
|
| 1182 |
+
" # Plot Accuracy vs SNR\n",
|
| 1183 |
+
" plt.figure()\n",
|
| 1184 |
+
" plt.plot(snr_values, accuracies, marker='o', label=\"Accuracy\")\n",
|
| 1185 |
+
" plt.xlabel(\"SNR (dB)\")\n",
|
| 1186 |
+
" plt.ylabel(\"Accuracy (%)\")\n",
|
| 1187 |
+
" plt.title(\"Accuracy vs. SNR (Threshold 0.5)\")\n",
|
| 1188 |
+
" plt.grid(True)\n",
|
| 1189 |
+
" plt.legend()\n",
|
| 1190 |
+
" plt.savefig(save_path / \"Accuracy_vs_SNR.png\")\n",
|
| 1191 |
+
" plt.savefig(save_path / \"Accuracy_vs_SNR.svg\")\n",
|
| 1192 |
+
" plt.show()\n",
|
| 1193 |
+
"\n",
|
| 1194 |
+
" # Plot Recall vs SNR for each threshold\n",
|
| 1195 |
+
" for th in iou_thresholds:\n",
|
| 1196 |
+
" plt.figure()\n",
|
| 1197 |
+
" plt.plot(snr_values, recalls[th], marker='o', label=f\"Recall at {th}\")\n",
|
| 1198 |
+
" plt.xlabel(\"SNR (dB)\")\n",
|
| 1199 |
+
" plt.ylabel(\"Recall\")\n",
|
| 1200 |
+
" plt.title(f\"Recall vs. SNR (Threshold {th})\")\n",
|
| 1201 |
+
" plt.grid(True)\n",
|
| 1202 |
+
" plt.legend()\n",
|
| 1203 |
+
" plt.savefig(save_path / f\"Recall_vs_SNR_{th}.png\")\n",
|
| 1204 |
+
" plt.savefig(save_path / f\"Recall_vs_SNR_{th}.svg\")\n",
|
| 1205 |
+
" plt.show()\n",
|
| 1206 |
+
"\n",
|
| 1207 |
+
"# Call this after running evaluate() to save and plot results\n",
|
| 1208 |
+
"save_and_plot_results(snr_metrics, iou_thresholds)"
|
| 1209 |
+
]
|
| 1210 |
+
},
|
| 1211 |
+
{
|
| 1212 |
+
"cell_type": "code",
|
| 1213 |
+
"execution_count": null,
|
| 1214 |
+
"id": "d0c0d3e8",
|
| 1215 |
+
"metadata": {},
|
| 1216 |
+
"outputs": [],
|
| 1217 |
+
"source": []
|
| 1218 |
+
}
|
| 1219 |
+
],
|
| 1220 |
+
"metadata": {
|
| 1221 |
+
"kernelspec": {
|
| 1222 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1223 |
+
"language": "python",
|
| 1224 |
+
"name": "python3"
|
| 1225 |
+
},
|
| 1226 |
+
"language_info": {
|
| 1227 |
+
"codemirror_mode": {
|
| 1228 |
+
"name": "ipython",
|
| 1229 |
+
"version": 3
|
| 1230 |
+
},
|
| 1231 |
+
"file_extension": ".py",
|
| 1232 |
+
"mimetype": "text/x-python",
|
| 1233 |
+
"name": "python",
|
| 1234 |
+
"nbconvert_exporter": "python",
|
| 1235 |
+
"pygments_lexer": "ipython3",
|
| 1236 |
+
"version": "3.10.9"
|
| 1237 |
+
}
|
| 1238 |
+
},
|
| 1239 |
+
"nbformat": 4,
|
| 1240 |
+
"nbformat_minor": 5
|
| 1241 |
+
}
|
CMuSeNet_Synthetic_IQ_Generator/README.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This matlab function is generated and tested in MATLAB 2022, and 2024
|
| 2 |
+
|
| 3 |
+
Please open datagen.m script and run it with MATLAB to generate synthetic dataset with same configuration as CMuSeNet synthetic dataset.
|
| 4 |
+
|
| 5 |
+
In this script you can change various setting such as channel (AWGN, Rician, Rayleigh), sample speed, range of SNR and sample bandwidth.
|
| 6 |
+
|
| 7 |
+
This dataset is used to train CMuSeNet, complex-valued multi-signla segmentation Network.
|
| 8 |
+
|
| 9 |
+
Please cite our paper if you use this dataset or synthetic dataset generation script.
|
| 10 |
+
|
| 11 |
+
@inproceedings{shin2025cmusenet,
|
| 12 |
+
title={I Can't Believe It's Not Real: {CV-MuSeNet}: Complex-Valued Multi-Signal Segmentation},
|
| 13 |
+
author={Sangwon Shin and Mehmet C. Vuran},
|
| 14 |
+
booktitle={IEEE Dynamic Spectrum Access Networks (DySPAN)},
|
| 15 |
+
year={2025},
|
| 16 |
+
organization={IEEE}
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
Acknowledgement: Office of Naval Research, NSWC Crane N00174-23-1-0007
|
| 20 |
+
This work relates to Department of Navy award N00174-23-1-0007 issued by the Office of Naval Research, NSWC Crane. Any opinions,
|
| 21 |
+
findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views of the Office of Naval Research.
|
| 22 |
+
|
| 23 |
+
Following IQ samples generation script is coded by Prashant Subedi, Sangwon Shin and Dr. Mehmet Can Vuran - Cyber Physical Networking (CPN) Lab at University of Nebraska - Lincoln
|
| 24 |
+
|
| 25 |
+
License:
|
| 26 |
+
This IQ samples is licensed under the GPL family (General Public License) terms.
|
CMuSeNet_Synthetic_IQ_Generator/datagen.m
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
path = "../diff-snr-matlab-simulated-data";
|
| 2 |
+
|
| 3 |
+
for snr = -20:2:10
|
| 4 |
+
disp(snr);
|
| 5 |
+
mkdir(sprintf("%s/%d/", path, snr));
|
| 6 |
+
for i = 1:5000
|
| 7 |
+
name = string(i);
|
| 8 |
+
channelType = 'awgn'; %Supported channel type: awgn, rician (Flat), rayleigh (Flat)
|
| 9 |
+
[meta, data] = datagenWideband(snr, channelType);
|
| 10 |
+
split = reshape([real(data) imag(data)].', 1, []);
|
| 11 |
+
|
| 12 |
+
% Save data file
|
| 13 |
+
mkdir(sprintf("%s/%d/%s", path, snr, name));
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
datafile = fopen(sprintf("%s/%d/%s/data.dat", path, snr, name), 'w');
|
| 17 |
+
fwrite(datafile, split, 'double');
|
| 18 |
+
fclose(datafile);
|
| 19 |
+
|
| 20 |
+
% Save meta file
|
| 21 |
+
metafile = fopen(sprintf("%s/%d/%s/meta-data.json", path, snr, name), 'w');
|
| 22 |
+
fprintf(metafile, jsonencode(meta));
|
| 23 |
+
fclose(metafile);
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
disp(name);
|
| 27 |
+
end
|
| 28 |
+
end
|
CMuSeNet_Synthetic_IQ_Generator/datagenTransmitter.m
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
function transmittedSignal = datagenTransmitter( ...
|
| 2 |
+
modulation, ...
|
| 3 |
+
rolloffFactor, ...
|
| 4 |
+
filterSpanInSymbols, ...
|
| 5 |
+
samplesPerSymbol, ...
|
| 6 |
+
symbolRate, ...
|
| 7 |
+
messageDuration ...
|
| 8 |
+
)
|
| 9 |
+
requiresFilter = true;
|
| 10 |
+
if modulation == "QPSK"
|
| 11 |
+
bitsPerSymbol = 2;
|
| 12 |
+
modulator = comm.QPSKModulator( ...
|
| 13 |
+
'BitInput', true, ...
|
| 14 |
+
'PhaseOffset', pi/4, ...
|
| 15 |
+
'OutputDataType', 'double' ...
|
| 16 |
+
);
|
| 17 |
+
elseif modulation == "BPSK"
|
| 18 |
+
bitsPerSymbol = 1;
|
| 19 |
+
modulator = comm.BPSKModulator;
|
| 20 |
+
elseif modulation == "8-PSK"
|
| 21 |
+
bitsPerSymbol = 3;
|
| 22 |
+
modulator = @(x) qammod(bit2int(x, 3), 8);
|
| 23 |
+
elseif modulation == "8-QAM"
|
| 24 |
+
bitsPerSymbol = 3;
|
| 25 |
+
modulator = @(x) pskmod(bit2int(x, 3), 8);
|
| 26 |
+
elseif modulation == "16-QAM"
|
| 27 |
+
bitsPerSymbol = 4;
|
| 28 |
+
modulator = @(x) qammod(bit2int(x, 4), 16);
|
| 29 |
+
elseif modulation == "GMSK"
|
| 30 |
+
bitsPerSymbol = 1;
|
| 31 |
+
modulator = comm.GMSKModulator("SamplesPerSymbol", samplesPerSymbol, ...
|
| 32 |
+
"BitInput", true);
|
| 33 |
+
requiresFilter = false;
|
| 34 |
+
elseif modulation == "2-FSK"
|
| 35 |
+
bitsPerSymbol = 1;
|
| 36 |
+
fdev = floor(symbolRate/4);
|
| 37 |
+
samplesPerSymbol = 8;
|
| 38 |
+
modulator = @(x) fskmod(x, 2, fdev, samplesPerSymbol, symbolRate);
|
| 39 |
+
requiresFilter = false;
|
| 40 |
+
else
|
| 41 |
+
error("Not implemented " + modulation);
|
| 42 |
+
end
|
| 43 |
+
|
| 44 |
+
transmittedBin = randi( ...
|
| 45 |
+
[0 1], ...
|
| 46 |
+
bitsPerSymbol * symbolRate * messageDuration/samplesPerSymbol, ...
|
| 47 |
+
1 ...
|
| 48 |
+
);
|
| 49 |
+
|
| 50 |
+
modulatedData = modulator(transmittedBin); % Modulates the bits into QPSK symbols
|
| 51 |
+
|
| 52 |
+
if requiresFilter
|
| 53 |
+
transmitterFilter = comm.RaisedCosineTransmitFilter( ...
|
| 54 |
+
'RolloffFactor', rolloffFactor, ...
|
| 55 |
+
'FilterSpanInSymbols', filterSpanInSymbols, ...
|
| 56 |
+
'OutputSamplesPerSymbol', samplesPerSymbol ...
|
| 57 |
+
);
|
| 58 |
+
transmittedSignal = transmitterFilter(modulatedData); % Square root Raised Cosine Transmit Filter
|
| 59 |
+
else
|
| 60 |
+
transmittedSignal = modulatedData;
|
| 61 |
+
end
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
end
|
CMuSeNet_Synthetic_IQ_Generator/datagenWideband.m
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
function [metadata, widebandSignal] = datagenWideband(SNRdB, fadingType)
|
| 2 |
+
% Constant for this function
|
| 3 |
+
RolloffFactor = 0.35;
|
| 4 |
+
RaisedCosineFilterSpan = 10;
|
| 5 |
+
Interpolation = 2;
|
| 6 |
+
NarrowBandBWs = [1e5, 2e5, 5e5, 1e6, 2e6];
|
| 7 |
+
WideBandBW = 20e6;
|
| 8 |
+
MaxSignals = 10;
|
| 9 |
+
% Modulations = ["QPSK" "BPSK" "8-PSK" "8-QAM" "16-QAM" "2-FSK" ];
|
| 10 |
+
Modulations = ["QPSK" "BPSK" "8-PSK" "8-QAM" "16-QAM", "GMSK", "2-FSK"];
|
| 11 |
+
SamplingTime = 2/1000; % 2ms
|
| 12 |
+
|
| 13 |
+
TxPowerRange = [0, 20];
|
| 14 |
+
|
| 15 |
+
numberOfSignals = randi([1, MaxSignals], 1);
|
| 16 |
+
|
| 17 |
+
signalBW = randsample(NarrowBandBWs, numberOfSignals, true);
|
| 18 |
+
txPowers = randi(TxPowerRange, [numberOfSignals, 1]);
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
minGap = 100e3; % 100kHz
|
| 22 |
+
|
| 23 |
+
maxBW = max(signalBW);
|
| 24 |
+
|
| 25 |
+
% Allocate a space for the frequencies
|
| 26 |
+
freqOffsets = [];
|
| 27 |
+
usedFreqs = [];
|
| 28 |
+
% A mechanism to prevent it from being stuck if there are too many
|
| 29 |
+
% wideband signals
|
| 30 |
+
maxLoops = numberOfSignals * 10;
|
| 31 |
+
% Generate non-overlapping frequencies
|
| 32 |
+
for i = 1:numberOfSignals
|
| 33 |
+
bw = signalBW(i);
|
| 34 |
+
% Generate a random frequency offset within the limits
|
| 35 |
+
while maxLoops > 0
|
| 36 |
+
maxLoops = maxLoops - 1; % prevent it from handing
|
| 37 |
+
freq = randi([-WideBandBW/2 + bw/2, WideBandBW/2 - bw/2]);
|
| 38 |
+
% Check if the frequency space for the new signal is already occupied or
|
| 39 |
+
% if the new signal is within minGap of an existing signal
|
| 40 |
+
overlap = false;
|
| 41 |
+
for j = 1:length(usedFreqs)
|
| 42 |
+
existing_bw = signalBW(j);
|
| 43 |
+
if abs(freq - usedFreqs(j)) < (bw + existing_bw)/2 + minGap
|
| 44 |
+
overlap = true;
|
| 45 |
+
break;
|
| 46 |
+
end
|
| 47 |
+
end
|
| 48 |
+
if ~overlap
|
| 49 |
+
% If not, add the frequency to the used frequencies and break the loop
|
| 50 |
+
usedFreqs = [usedFreqs freq];
|
| 51 |
+
freqOffsets = [freqOffsets freq];
|
| 52 |
+
break
|
| 53 |
+
end
|
| 54 |
+
% If the frequency space is occupied or too close to another signal,
|
| 55 |
+
% generate a new random frequency
|
| 56 |
+
end
|
| 57 |
+
|
| 58 |
+
if maxLoops <= 0
|
| 59 |
+
numberOfSignals = length(freqOffsets);
|
| 60 |
+
disp("Stopping because couldn't place signal");
|
| 61 |
+
disp(signalBW);
|
| 62 |
+
break;
|
| 63 |
+
end
|
| 64 |
+
|
| 65 |
+
end
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
signals = [];
|
| 69 |
+
metadata = [];
|
| 70 |
+
|
| 71 |
+
lowestPowerSignal = min(txPowers);
|
| 72 |
+
noisePower = min(txPowers) - SNRdB;
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
for i = 1: numberOfSignals
|
| 76 |
+
modulation = randsample(Modulations, 1);
|
| 77 |
+
txPower = txPowers(i);
|
| 78 |
+
bw = signalBW(i);
|
| 79 |
+
% Should the divisor be 20 ?
|
| 80 |
+
signal = datagenTransmitter( ...
|
| 81 |
+
modulation, ...
|
| 82 |
+
RolloffFactor, ...
|
| 83 |
+
RaisedCosineFilterSpan, ...
|
| 84 |
+
Interpolation, ...
|
| 85 |
+
bw, ...
|
| 86 |
+
SamplingTime...811
|
| 87 |
+
);
|
| 88 |
+
|
| 89 |
+
% Scale the signal
|
| 90 |
+
signal = signal/sqrt(mean(abs(signal).^2));
|
| 91 |
+
|
| 92 |
+
% Scale to correct power
|
| 93 |
+
signal = 10^(txPower/20)*signal;
|
| 94 |
+
|
| 95 |
+
pwr = 10*log10(mean(abs(signal).^2));
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if bw ~= maxBW
|
| 99 |
+
signal = resample(signal, maxBW/1e5, bw/1e5);
|
| 100 |
+
end
|
| 101 |
+
signals = [signals signal];
|
| 102 |
+
metadata = [metadata; struct("fc", freqOffsets(i), "bw", bw, "mod", modulation, "txPower", txPower, "noisePower", noisePower)];
|
| 103 |
+
|
| 104 |
+
end
|
| 105 |
+
mbc = comm.MultibandCombiner( ...
|
| 106 |
+
InputSampleRate=maxBW, ...
|
| 107 |
+
FrequencyOffsets=freqOffsets, ...
|
| 108 |
+
OutputSampleRateSource="property", ...
|
| 109 |
+
OutputSampleRate=WideBandBW ...
|
| 110 |
+
);
|
| 111 |
+
|
| 112 |
+
combinedsig = mbc(signals);
|
| 113 |
+
% Channel configuration
|
| 114 |
+
fd = 30; % Max Doppler shift in Hz
|
| 115 |
+
Ts = 1/WideBandBW; % Sampling time
|
| 116 |
+
chan = [];
|
| 117 |
+
|
| 118 |
+
switch lower(fadingType)
|
| 119 |
+
case 'awgn'
|
| 120 |
+
% Just noise without fading
|
| 121 |
+
widebandSignal = awgn(combinedsig, SNRdB, lowestPowerSignal);
|
| 122 |
+
|
| 123 |
+
case 'rayleigh'
|
| 124 |
+
rayleighChan = comm.RayleighChannel( ...
|
| 125 |
+
'SampleRate', WideBandBW, ...
|
| 126 |
+
'PathDelays', 0, ...
|
| 127 |
+
'AveragePathGains', 0, ...
|
| 128 |
+
'MaximumDopplerShift', 30 ...
|
| 129 |
+
);
|
| 130 |
+
fadedSignal = rayleighChan(combinedsig);
|
| 131 |
+
widebandSignal = awgn(fadedSignal, SNRdB, lowestPowerSignal); % Add AWGN
|
| 132 |
+
|
| 133 |
+
case 'rician'
|
| 134 |
+
ricianChan = comm.RicianChannel( ...
|
| 135 |
+
'SampleRate', WideBandBW, ...
|
| 136 |
+
'PathDelays', 0, ...
|
| 137 |
+
'AveragePathGains', 0, ...
|
| 138 |
+
'KFactor', 10, ...
|
| 139 |
+
'MaximumDopplerShift', 30 ...
|
| 140 |
+
);
|
| 141 |
+
fadedSignal = ricianChan(combinedsig);
|
| 142 |
+
widebandSignal = awgn(fadedSignal, SNRdB, lowestPowerSignal); % Add AWGN
|
| 143 |
+
|
| 144 |
+
otherwise
|
| 145 |
+
error('Unsupported fading type: %s', fadingType);
|
| 146 |
+
end
|
| 147 |
+
end
|