sshinCPN commited on
Commit
6e644c1
·
verified ·
1 Parent(s): 7914754

CMuSeNet Training / Validation code and Synthetic IQ samples generator

Browse files
CMuSeNet_BIGRED.ipynb ADDED
@@ -0,0 +1,1277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "b5007b71",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Initialization"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "3e6b1226",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from pathlib import Path\n",
19
+ "import numpy as np\n",
20
+ "from scipy.signal import welch\n",
21
+ "import torch\n",
22
+ "from torch.utils.data import Dataset, DataLoader\n",
23
+ "from tqdm import tqdm\n",
24
+ "import math\n",
25
+ "import json\n",
26
+ "\n",
27
+ "# Constants\n",
28
+ "START_INDEX = 10 # Skip first few samples\n",
29
+ "SIGNAL_LENGTH = 1024 * 16\n",
30
+ "SAMPLE_RATE = 20e6\n",
31
+ "MASK_SIZE = 1024 * 16 # Mask size for segmentation\n",
32
+ "\n",
33
+ "# Functions for Signal Processing\n",
34
+ "def load_real_data(sample_path):\n",
35
+ " \"\"\"\n",
36
+ " Load raw signal data from a .dat file.\n",
37
+ " \"\"\"\n",
38
+ " with open(sample_path, \"rb\") as f:\n",
39
+ " signal = np.fromfile(f, dtype=np.complex64)\n",
40
+ " return signal\n",
41
+ "\n",
42
+ "def load_data(signal_id):\n",
43
+ " \"\"\"\n",
44
+ " Load signal data and its corresponding metadata.\n",
45
+ " \"\"\"\n",
46
+ " signal = load_real_data(signal_id)\n",
47
+ " metadata_file = signal_id.with_suffix(\".json\")\n",
48
+ " if metadata_file.exists():\n",
49
+ " with open(metadata_file, \"r\") as f:\n",
50
+ " metadata = json.load(f)\n",
51
+ " else:\n",
52
+ " raise FileNotFoundError(f\"Metadata file {metadata_file} not found for signal {signal_id}\")\n",
53
+ " return signal[START_INDEX:], metadata, metadata_file\n",
54
+ "\n",
55
+ "def apply_psd(signal, Fs, NFFT):\n",
56
+ " \"\"\"\n",
57
+ " Calculate the PSD and corresponding frequencies using Welch's method.\n",
58
+ " \"\"\"\n",
59
+ " freqs, psd = welch(signal, fs=Fs, nfft=NFFT, return_onesided=False)\n",
60
+ " psd = np.fft.fftshift(psd)\n",
61
+ " freqs = np.fft.fftshift(freqs)\n",
62
+ " return psd, freqs\n",
63
+ "\n",
64
+ "def calculate_fft(signal):\n",
65
+ " \"\"\"\n",
66
+ " Calculate the FFT of the signal and return real and imaginary parts as separate channels.\n",
67
+ " \"\"\"\n",
68
+ " signal = signal[:SIGNAL_LENGTH]\n",
69
+ " signal = np.fft.fft(signal)\n",
70
+ " signal = np.fft.fftshift(signal)\n",
71
+ " signal /= np.max(np.abs(signal))\n",
72
+ " return signal"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "id": "440b802c",
78
+ "metadata": {},
79
+ "source": [
80
+ "### Data Loading"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "31bc3770",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "# Dataset Class\n",
91
+ "class WidebandSignalDataset(Dataset):\n",
92
+ " def __init__(self, signal_ids, mask_size=1024 * 16):\n",
93
+ " \"\"\"\n",
94
+ " Initialize the dataset with signal IDs and the specified mask size.\n",
95
+ " \"\"\"\n",
96
+ " self.mask_size = mask_size\n",
97
+ " self.signal_ids = signal_ids\n",
98
+ " self.loaded_data = [self.process_signal(signal_id) for signal_id in tqdm(self.signal_ids)]\n",
99
+ "\n",
100
+ " def __len__(self):\n",
101
+ " return len(self.signal_ids)\n",
102
+ "\n",
103
+ " def __getitem__(self, index):\n",
104
+ " return self.loaded_data[index]\n",
105
+ "\n",
106
+ " def process_signal(self, signal_id):\n",
107
+ " signal, metadata, _ = load_data(signal_id)\n",
108
+ "\n",
109
+ " # Ensure signal length matches SIGNAL_LENGTH\n",
110
+ " if len(signal) < SIGNAL_LENGTH:\n",
111
+ " # Pad with zeros if the signal is shorter\n",
112
+ " signal = np.pad(signal, (0, SIGNAL_LENGTH - len(signal)), mode='constant')\n",
113
+ " elif len(signal) > SIGNAL_LENGTH:\n",
114
+ " # Truncate if the signal is longer\n",
115
+ " signal = signal[:SIGNAL_LENGTH]\n",
116
+ "\n",
117
+ " # Apply FFT\n",
118
+ " signal = np.fft.fft(signal)\n",
119
+ " signal = np.fft.fftshift(signal)\n",
120
+ " signal /= np.max(np.abs(signal)) # Normalize\n",
121
+ " complex_signal = torch.from_numpy(signal).type(torch.complex64).unsqueeze(0) # Add channel dimension\n",
122
+ "\n",
123
+ " # Create mask with fixed size\n",
124
+ " masks = torch.zeros(self.mask_size, dtype=torch.float32)\n",
125
+ " scale_ratio = self.mask_size / SAMPLE_RATE\n",
126
+ " scaled_metadata = process_metadata(metadata)\n",
127
+ " for meta in scaled_metadata:\n",
128
+ " f1, f2 = meta[\"position\"]\n",
129
+ " x1 = int(math.floor(f1 * scale_ratio))\n",
130
+ " x2 = int(math.ceil(f2 * scale_ratio))\n",
131
+ " masks[x1:x2] = 1\n",
132
+ "\n",
133
+ " return complex_signal, masks\n",
134
+ "\n",
135
+ "\n",
136
+ "\n",
137
+ "def process_metadata(metadata):\n",
138
+ " \"\"\"\n",
139
+ " Scale metadata to the dataset's frequency and bandwidth ranges.\n",
140
+ " \"\"\"\n",
141
+ " scaled_metadata = [\n",
142
+ " {\n",
143
+ " \"position\": (\n",
144
+ " math.floor((SAMPLE_RATE / 2 + i[\"fc\"] - i[\"bw\"] / 2) * SIGNAL_LENGTH / SAMPLE_RATE),\n",
145
+ " math.ceil((SAMPLE_RATE / 2 + i[\"fc\"] + i[\"bw\"] / 2) * SIGNAL_LENGTH / SAMPLE_RATE)\n",
146
+ " ),\n",
147
+ " \"snr\": 1, # Placeholder value\n",
148
+ " \"bw\": i[\"bw\"],\n",
149
+ " \"num\": len(metadata),\n",
150
+ " \"esn0\": 1, # Placeholder value\n",
151
+ " }\n",
152
+ " for i in metadata\n",
153
+ " ]\n",
154
+ " return scaled_metadata\n",
155
+ "\n",
156
+ "# Dataset Splitting and Initialization\n",
157
+ "NEW_DATA_DIR = Path(\"/data/bigred/ofh/0\")\n",
158
+ "def get_real_signals(freq_directory):\n",
159
+ " return list(freq_directory.rglob(\"*.dat\"))\n",
160
+ "\n",
161
+ "signal_dirs = get_real_signals(NEW_DATA_DIR)\n",
162
+ "total_signals = len(signal_dirs)\n",
163
+ "\n",
164
+ "train_split = int(0.80 * total_signals)\n",
165
+ "validation_split = int(0.90 * total_signals)\n",
166
+ "\n",
167
+ "train, validation, test = (\n",
168
+ " signal_dirs[:train_split],\n",
169
+ " signal_dirs[train_split:validation_split],\n",
170
+ " signal_dirs[validation_split:]\n",
171
+ ")\n",
172
+ "\n",
173
+ "print(f\"Train set size: {len(train)}\")\n",
174
+ "print(f\"Validation set size: {len(validation)}\")\n",
175
+ "print(f\"Test set size: {len(test)}\")"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "f5305642",
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "# Data Loaders\n",
186
+ "BATCH_SIZE = 64\n",
187
+ "\n",
188
+ "train_dataset = WidebandSignalDataset(signal_ids=train)\n",
189
+ "validation_dataset = WidebandSignalDataset(signal_ids=validation)\n",
190
+ "test_dataset = WidebandSignalDataset(signal_ids=test)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "id": "54a4f325",
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
201
+ "valid_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
202
+ "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "markdown",
207
+ "id": "3893c583",
208
+ "metadata": {},
209
+ "source": [
210
+ "### CV-ResNet-18"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "id": "bc2001c4",
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "import torch\n",
221
+ "import torch.nn as nn\n",
222
+ "import complexPyTorch.complexLayers as cplx\n",
223
+ "from typing import Optional, Callable, Type, Union, List\n",
224
+ "import torch.nn.functional as F\n",
225
+ "from torch import Tensor\n",
226
+ "\n",
227
+ "def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
228
+ " \"\"\"3x3 convolution with padding\"\"\"\n",
229
+ " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
230
+ "\n",
231
+ "def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
232
+ " \"\"\"1x1 convolution\"\"\"\n",
233
+ " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
234
+ "\n",
235
+ "class BasicBlock(nn.Module):\n",
236
+ " expansion = 1\n",
237
+ "\n",
238
+ " def __init__(\n",
239
+ " self,\n",
240
+ " inplanes: int,\n",
241
+ " planes: int,\n",
242
+ " stride: int = 1,\n",
243
+ " downsample: Optional[nn.Module] = None,\n",
244
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
245
+ " ) -> None:\n",
246
+ " super(BasicBlock, self).__init__()\n",
247
+ " self.conv1 = conv3x3(inplanes, planes, stride)\n",
248
+ " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
249
+ " self.relu = cplx.ComplexReLU()\n",
250
+ " self.conv2 = conv3x3(planes, planes)\n",
251
+ " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
252
+ " self.downsample = downsample\n",
253
+ " self.stride = stride\n",
254
+ "\n",
255
+ " def forward(self, x: Tensor) -> Tensor:\n",
256
+ " identity = x\n",
257
+ "\n",
258
+ " out = self.conv1(x)\n",
259
+ " out = self.bn1(out)\n",
260
+ " out = self.relu(out)\n",
261
+ "\n",
262
+ " out = self.conv2(out)\n",
263
+ " out = self.bn2(out)\n",
264
+ "\n",
265
+ " if self.downsample is not None:\n",
266
+ " identity = self.downsample(x)\n",
267
+ "\n",
268
+ " out += identity\n",
269
+ " out = self.relu(out)\n",
270
+ "\n",
271
+ " return out\n",
272
+ "\n",
273
+ "class Bottleneck(nn.Module):\n",
274
+ " expansion = 4\n",
275
+ "\n",
276
+ " def __init__(\n",
277
+ " self,\n",
278
+ " inplanes: int,\n",
279
+ " planes: int,\n",
280
+ " stride: int = 1,\n",
281
+ " downsample: Optional[nn.Module] = None,\n",
282
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
283
+ " ) -> None:\n",
284
+ " super(Bottleneck, self).__init__()\n",
285
+ " self.conv1 = conv1x1(inplanes, planes)\n",
286
+ " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
287
+ " self.conv2 = conv3x3(planes, planes, stride)\n",
288
+ " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
289
+ " self.conv3 = conv1x1(planes, planes * self.expansion)\n",
290
+ " self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n",
291
+ " self.relu = cplx.ComplexReLU()\n",
292
+ " self.downsample = downsample\n",
293
+ " self.stride = stride\n",
294
+ "\n",
295
+ " def forward(self, x: Tensor) -> Tensor:\n",
296
+ " identity = x\n",
297
+ "\n",
298
+ " out = self.conv1(x)\n",
299
+ " out = self.bn1(out)\n",
300
+ " out = self.relu(out)\n",
301
+ "\n",
302
+ " out = self.conv2(out)\n",
303
+ " out = self.bn2(out)\n",
304
+ " out = self.relu(out)\n",
305
+ "\n",
306
+ " out = self.conv3(out)\n",
307
+ " out = self.bn3(out)\n",
308
+ "\n",
309
+ " if self.downsample is not None:\n",
310
+ " identity = self.downsample(x)\n",
311
+ "\n",
312
+ " out += identity\n",
313
+ " out = self.relu(out)\n",
314
+ "\n",
315
+ " return out\n",
316
+ "\n",
317
+ "class ComplexResNet(nn.Module):\n",
318
+ " def __init__(\n",
319
+ " self,\n",
320
+ " block: Type[Union[BasicBlock, Bottleneck]],\n",
321
+ " layers: List[int],\n",
322
+ " num_classes: int = SIGNAL_LENGTH,\n",
323
+ " zero_init_residual: bool = False,\n",
324
+ " groups: int = 1,\n",
325
+ " width_per_group: int = 64,\n",
326
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
327
+ " ) -> None:\n",
328
+ " super(ComplexResNet, self).__init__()\n",
329
+ " if norm_layer is None:\n",
330
+ " norm_layer = cplx.ComplexBatchNorm2d\n",
331
+ " self._norm_layer = norm_layer\n",
332
+ "\n",
333
+ " self.inplanes = 64\n",
334
+ " self.dilation = 1\n",
335
+ "\n",
336
+ " self.groups = groups\n",
337
+ " self.base_width = width_per_group\n",
338
+ " self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n",
339
+ " self.bn1 = norm_layer(self.inplanes)\n",
340
+ " self.relu = cplx.ComplexReLU()\n",
341
+ " self.layer1 = self._make_layer(block, 64, layers[0])\n",
342
+ " self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
343
+ " self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
344
+ " self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
345
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
346
+ " self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n",
347
+ " self.sigmoid = cplx.ComplexSigmoid()\n",
348
+ "\n",
349
+ " def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n",
350
+ " norm_layer = self._norm_layer\n",
351
+ " downsample = None\n",
352
+ " if stride != 1 or self.inplanes != planes * block.expansion:\n",
353
+ " downsample = nn.Sequential(\n",
354
+ " conv1x1(self.inplanes, planes * block.expansion, stride),\n",
355
+ " norm_layer(planes * block.expansion),\n",
356
+ " )\n",
357
+ "\n",
358
+ " layers = []\n",
359
+ " layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n",
360
+ " self.inplanes = planes * block.expansion\n",
361
+ " for _ in range(1, blocks):\n",
362
+ " layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n",
363
+ "\n",
364
+ " return nn.Sequential(*layers)\n",
365
+ "\n",
366
+ " def _forward_impl(self, x: Tensor) -> Tensor:\n",
367
+ " x = self.conv1(x)\n",
368
+ " x = self.bn1(x)\n",
369
+ " x = self.relu(x)\n",
370
+ "\n",
371
+ " x = self.layer1(x)\n",
372
+ " x = self.layer2(x)\n",
373
+ " x = self.layer3(x)\n",
374
+ " x = self.layer4(x)\n",
375
+ "\n",
376
+ " x = self.avgpool(x)\n",
377
+ " x = torch.flatten(x, 1)\n",
378
+ " x = self.fc(x)\n",
379
+ " x = self.sigmoid(x)\n",
380
+ " return x\n",
381
+ "\n",
382
+ " def forward(self, x: Tensor) -> Tensor:\n",
383
+ " return self._forward_impl(x)\n",
384
+ "\n",
385
+ "def ComplexResNet18():\n",
386
+ " return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n",
387
+ "\n",
388
+ "# Create the model instance\n",
389
+ "model = ComplexResNet18()\n",
390
+ "print(model)\n"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "id": "9a8e09e4",
396
+ "metadata": {},
397
+ "source": [
398
+ "### Early Stop"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "id": "24f79a24",
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "import os\n",
409
+ "\n",
410
+ "class EarlyStopping:\n",
411
+ " def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./path/to/model/save'):\n",
412
+ " self.patience = patience\n",
413
+ " self.verbose = verbose\n",
414
+ " self.delta = delta\n",
415
+ " self.counter = 0\n",
416
+ " self.best_score = None\n",
417
+ " self.early_stop = False\n",
418
+ " self.val_loss_min = float('inf')\n",
419
+ " self.best_model = None\n",
420
+ " self.save_path = save_path\n",
421
+ " os.makedirs(save_path, exist_ok=True)\n",
422
+ " \n",
423
+ " def __call__(self, val_loss, model):\n",
424
+ " score = -val_loss\n",
425
+ "\n",
426
+ " if self.best_score is None:\n",
427
+ " self.best_score = score\n",
428
+ " self.save_checkpoint(val_loss, model)\n",
429
+ " elif score < self.best_score + self.delta:\n",
430
+ " self.counter += 1\n",
431
+ " if self.verbose:\n",
432
+ " print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
433
+ " if self.counter >= self.patience:\n",
434
+ " self.early_stop = True\n",
435
+ " else:\n",
436
+ " self.best_score = score\n",
437
+ " self.save_checkpoint(val_loss, model)\n",
438
+ " self.counter = 0\n",
439
+ "\n",
440
+ " def save_checkpoint(self, val_loss, model):\n",
441
+ " if self.verbose:\n",
442
+ " print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n",
443
+ " self.val_loss_min = val_loss\n",
444
+ " self.best_model = model.state_dict()\n",
445
+ " save_path = os.path.join(self.save_path, 'best_model.pth')\n",
446
+ " torch.save(self.best_model, save_path)"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "markdown",
451
+ "id": "6c3fda74",
452
+ "metadata": {},
453
+ "source": [
454
+ "### Focal loss and reshape"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "id": "5fcf91db",
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": [
464
+ "class ComplexFocalLoss(nn.Module):\n",
465
+ " def __init__(self, alpha=1, gamma=2, reduction='mean'):\n",
466
+ " super(ComplexFocalLoss, self).__init__()\n",
467
+ " self.alpha = alpha\n",
468
+ " self.gamma = gamma\n",
469
+ " self.reduction = reduction\n",
470
+ "\n",
471
+ " def forward(self, inputs, targets):\n",
472
+ " real_inputs = inputs.real\n",
473
+ " imag_inputs = inputs.imag\n",
474
+ " \n",
475
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
476
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
477
+ " \n",
478
+ " real_pt = torch.exp(-real_BCE_loss)\n",
479
+ " imag_pt = torch.exp(-imag_BCE_loss)\n",
480
+ " \n",
481
+ " real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
482
+ " imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
483
+ "\n",
484
+ " if self.reduction == 'mean':\n",
485
+ " return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
486
+ " elif self.reduction == 'sum':\n",
487
+ " return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
488
+ " else:\n",
489
+ " return real_F_loss + imag_F_loss\n",
490
+ "\n",
491
+ "# Update the IoU calculation to handle complex values\n",
492
+ "def calculate_iou(pred, target, threshold=0.5):\n",
493
+ " real_pred = (pred.real > threshold).float()\n",
494
+ " imag_pred = (pred.imag > threshold).float()\n",
495
+ " \n",
496
+ " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
497
+ " \n",
498
+ " intersection = (combined_pred * target).sum(dim=1)\n",
499
+ " union = (combined_pred + target).sum(dim=1) - intersection\n",
500
+ " iou = (intersection / union).mean().item()\n",
501
+ " return iou\n",
502
+ "def reshape_to_2d(data):\n",
503
+ " return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "markdown",
508
+ "id": "c97635b0",
509
+ "metadata": {},
510
+ "source": [
511
+ "### BCE Loss"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": null,
517
+ "id": "2e8b2892",
518
+ "metadata": {},
519
+ "outputs": [],
520
+ "source": [
521
+ "# CV BCE Loss Function Definition\n",
522
+ "class ComplexValuedBCELoss(nn.Module):\n",
523
+ " def __init__(self, reduction='mean'):\n",
524
+ " super(ComplexValuedBCELoss, self).__init__()\n",
525
+ " self.reduction = reduction\n",
526
+ "\n",
527
+ " def forward(self, inputs, targets):\n",
528
+ " real_inputs = inputs.real\n",
529
+ " imag_inputs = inputs.imag\n",
530
+ "\n",
531
+ " # Calculate binary cross-entropy for both real and imaginary parts\n",
532
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
533
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n",
534
+ " \n",
535
+ " # Combine the losses (you can adjust the weighting if necessary)\n",
536
+ " combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n",
537
+ " return combined_BCE_loss"
538
+ ]
539
+ },
540
+ {
541
+ "cell_type": "markdown",
542
+ "id": "64f4063c",
543
+ "metadata": {},
544
+ "source": [
545
+ "### Training from scratch"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": null,
551
+ "id": "66825110",
552
+ "metadata": {
553
+ "scrolled": false
554
+ },
555
+ "outputs": [],
556
+ "source": [
557
+ "import time\n",
558
+ "device=\"cuda\"\n",
559
+ "def validate_model(model, valid_loader, criterion):\n",
560
+ " model.eval()\n",
561
+ " running_loss = 0.0\n",
562
+ " iou_scores = []\n",
563
+ " total_correct = 0\n",
564
+ " total_samples = 0\n",
565
+ "\n",
566
+ " with torch.no_grad():\n",
567
+ " for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n",
568
+ " inputs = reshape_to_2d(inputs).to(device)\n",
569
+ " masks = masks.to(device)\n",
570
+ " outputs = model(inputs)\n",
571
+ " loss = criterion(outputs, masks)\n",
572
+ " running_loss += loss.item()\n",
573
+ "\n",
574
+ " # Calculate IoU\n",
575
+ " iou = calculate_iou(outputs, masks, threshold=0.5)\n",
576
+ " iou_scores.append(iou)\n",
577
+ " \n",
578
+ " # Calculate accuracy\n",
579
+ " preds = (outputs.real > 0.5).float()\n",
580
+ " correct = (preds == masks).float().sum()\n",
581
+ " total_correct += correct.item()\n",
582
+ " total_samples += masks.numel()\n",
583
+ "\n",
584
+ " val_loss = running_loss / len(valid_loader)\n",
585
+ " mean_iou = sum(iou_scores) / len(iou_scores)\n",
586
+ " accuracy = total_correct / total_samples * 100\n",
587
+ "\n",
588
+ " print(f'Validation Loss: {val_loss:.6f}')\n",
589
+ " print(f'Validation Accuracy: {accuracy:.2f}%')\n",
590
+ "\n",
591
+ " return val_loss, accuracy\n",
592
+ "\n",
593
+ "def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.0001], num_epochs=50, patience=5):\n",
594
+ " train_losses = []\n",
595
+ " val_losses = []\n",
596
+ " val_accuracies = []\n",
597
+ " epoch_durations = []\n",
598
+ " \n",
599
+ " current_lr = initial_lr\n",
600
+ " for lr in lr_steps:\n",
601
+ " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
602
+ " early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n",
603
+ " print(\"Current learning rate: \", lr)\n",
604
+ " for epoch in range(num_epochs):\n",
605
+ " epoch_start_time = time.time()\n",
606
+ " \n",
607
+ " model.train()\n",
608
+ " running_loss = 0.0\n",
609
+ " for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n",
610
+ " inputs = reshape_to_2d(inputs).to(device)\n",
611
+ " masks = masks.to(device)\n",
612
+ " outputs = model(inputs)\n",
613
+ " loss = criterion(outputs, masks)\n",
614
+ "\n",
615
+ " optimizer.zero_grad()\n",
616
+ " loss.backward()\n",
617
+ " optimizer.step()\n",
618
+ "\n",
619
+ " running_loss += loss.item()\n",
620
+ "\n",
621
+ " epoch_loss = running_loss / len(train_loader)\n",
622
+ " train_losses.append(epoch_loss)\n",
623
+ " print(f\"Training Loss: {epoch_loss:.6f}\")\n",
624
+ " val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n",
625
+ " val_losses.append(val_loss)\n",
626
+ " val_accuracies.append(val_accuracy)\n",
627
+ " early_stopping(val_loss, model)\n",
628
+ "\n",
629
+ " if early_stopping.early_stop:\n",
630
+ " print(\"Early stopping triggered\")\n",
631
+ " break\n",
632
+ "\n",
633
+ " epoch_duration = time.time() - epoch_start_time\n",
634
+ " epoch_durations.append(epoch_duration)\n",
635
+ " if early_stopping.best_model is not None:\n",
636
+ " print(f\"Loading best model from lr {lr}\")\n",
637
+ " model.load_state_dict(early_stopping.best_model)\n",
638
+ " \n",
639
+ " print(\"Training completed.\")\n",
640
+ " print(\"Epoch durations:\", epoch_durations)\n",
641
+ " return model, train_losses, val_losses, val_accuracies, epoch_durations"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "id": "621d28b3",
648
+ "metadata": {
649
+ "scrolled": false
650
+ },
651
+ "outputs": [],
652
+ "source": [
653
+ "# Initialize and train the ResNet-18 model\n",
654
+ "model = ComplexResNet18().to(device)\n",
655
+ "criterion = ComplexFocalLoss()\n",
656
+ "\n",
657
+ "model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n",
658
+ "combined_epoch_time = sum(epoch_durations)\n",
659
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "markdown",
664
+ "id": "3838c1bc",
665
+ "metadata": {},
666
+ "source": [
667
+ "### Transfer Learning Load pretrained model"
668
+ ]
669
+ },
670
+ {
671
+ "cell_type": "code",
672
+ "execution_count": null,
673
+ "id": "ac763e75",
674
+ "metadata": {},
675
+ "outputs": [],
676
+ "source": [
677
+ "# Path to the pre-trained model weights\n",
678
+ "pretrained_model_path = \"path/to/model/save.pth\" #Change this model to trained model\n",
679
+ "device=\"cuda\"\n",
680
+ "# Initialize the model architecture\n",
681
+ "model = ComplexResNet18().to(device)\n",
682
+ "\n",
683
+ "# Load the pre-trained weights\n",
684
+ "checkpoint = torch.load(pretrained_model_path)\n",
685
+ "model.load_state_dict(checkpoint, strict=False)\n",
686
+ "\n",
687
+ "# Set all layers as trainable (if needed)\n",
688
+ "for param in model.parameters():\n",
689
+ " param.requires_grad = True"
690
+ ]
691
+ },
692
+ {
693
+ "cell_type": "code",
694
+ "execution_count": null,
695
+ "id": "1f877827",
696
+ "metadata": {
697
+ "scrolled": false
698
+ },
699
+ "outputs": [],
700
+ "source": [
701
+ "# Define a new criterion and optimizer for fine-tuning\n",
702
+ "# You may select between Focal Loss or BCE as your criterion\n",
703
+ "#criterion = ComplexValuedBCELoss() # or ComplexValuedBCELoss()\n",
704
+ "criterion = ComplexFocalLoss()\n",
705
+ "# Use a smaller learning rate for fine-tuning\n",
706
+ "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n",
707
+ "\n",
708
+ "# Train the model (fine-tuning)\n",
709
+ "model, train_losses, val_losses, val_accuracies, epoch_durations= train_model(\n",
710
+ " model, train_loader, valid_loader, criterion,\n",
711
+ " initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
712
+ ")\n",
713
+ "combined_epoch_time = sum(epoch_durations)\n",
714
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
715
+ ]
716
+ },
717
+ {
718
+ "cell_type": "markdown",
719
+ "id": "f3784964",
720
+ "metadata": {},
721
+ "source": [
722
+ "### Plot Result and save the figures and json"
723
+ ]
724
+ },
725
+ {
726
+ "cell_type": "code",
727
+ "execution_count": null,
728
+ "id": "67a52e13",
729
+ "metadata": {
730
+ "scrolled": false
731
+ },
732
+ "outputs": [],
733
+ "source": [
734
+ "import os\n",
735
+ "import json\n",
736
+ "import matplotlib.pyplot as plt\n",
737
+ "\n",
738
+ "# Define save directory\n",
739
+ "save_dir = 'CMuSeNet_results/segmentation'\n",
740
+ "\n",
741
+ "# Create the directory if it doesn't exist\n",
742
+ "os.makedirs(save_dir, exist_ok=True)\n",
743
+ "\n",
744
+ "# Plot training loss\n",
745
+ "plt.figure()\n",
746
+ "plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', color='blue')\n",
747
+ "plt.title('Training Loss')\n",
748
+ "plt.xlabel('Epoch')\n",
749
+ "plt.ylabel('Loss')\n",
750
+ "plt.legend()\n",
751
+ "\n",
752
+ "# Save the training loss figure as PNG and SVG\n",
753
+ "plt.savefig(os.path.join(save_dir, 'training_loss.png'))\n",
754
+ "plt.savefig(os.path.join(save_dir, 'training_loss.svg'))\n",
755
+ "\n",
756
+ "# Show the training loss plot\n",
757
+ "plt.show()\n",
758
+ "\n",
759
+ "# Plot validation accuracy\n",
760
+ "plt.figure()\n",
761
+ "plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', color='green')\n",
762
+ "plt.title('Validation Accuracy')\n",
763
+ "plt.xlabel('Epoch')\n",
764
+ "plt.ylabel('Accuracy')\n",
765
+ "plt.legend()\n",
766
+ "\n",
767
+ "# Save the validation accuracy figure as PNG and SVG\n",
768
+ "plt.savefig(os.path.join(save_dir, 'validation_accuracy.png'))\n",
769
+ "plt.savefig(os.path.join(save_dir, 'validation_accuracy.svg'))\n",
770
+ "\n",
771
+ "# Show the validation accuracy plot\n",
772
+ "plt.show()\n",
773
+ "\n",
774
+ "# Save the actual data to a JSON file\n",
775
+ "results = {\n",
776
+ " \"train_losses\": train_losses,\n",
777
+ " \"val_accuracies\": val_accuracies\n",
778
+ "}\n",
779
+ "\n",
780
+ "# Save JSON file\n",
781
+ "with open(os.path.join(save_dir, 'training_validation_results.json'), 'w') as f:\n",
782
+ " json.dump(results, f)\n"
783
+ ]
784
+ },
785
+ {
786
+ "cell_type": "markdown",
787
+ "id": "222069ae",
788
+ "metadata": {},
789
+ "source": [
790
+ "### BIG-RED Evaluation (Over entire dataset)"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "code",
795
+ "execution_count": null,
796
+ "id": "6b178984",
797
+ "metadata": {},
798
+ "outputs": [],
799
+ "source": [
800
+ "import torch\n",
801
+ "from torch.utils.data import DataLoader\n",
802
+ "from tqdm import tqdm\n",
803
+ "# Create a DataLoader for the entire dataset\n",
804
+ "BATCH_SIZE = 64 # Adjust based on available memory\n",
805
+ "entire_dataset = WidebandSignalDataset(signal_ids=signal_dirs) # Use all signals\n",
806
+ "entire_loader = DataLoader(entire_dataset, batch_size=BATCH_SIZE, shuffle=False)"
807
+ ]
808
+ },
809
+ {
810
+ "cell_type": "code",
811
+ "execution_count": null,
812
+ "id": "2e6be59a",
813
+ "metadata": {},
814
+ "outputs": [],
815
+ "source": [
816
+ "# Path to the pre-trained model weights\n",
817
+ "pretrained_model_path = \"path/to/model/pretrained\" \n",
818
+ "device = \"cuda\" \n",
819
+ "\n",
820
+ "# Initialize the model architecture\n",
821
+ "model = ComplexResNet18().to(device)\n",
822
+ "\n",
823
+ "# Load the pre-trained weights\n",
824
+ "checkpoint = torch.load(pretrained_model_path, map_location=device)\n",
825
+ "model.load_state_dict(checkpoint, strict=False)\n",
826
+ "model.eval()\n",
827
+ "\n",
828
+ "# Function to evaluate accuracy\n",
829
+ "def evaluate_accuracy(model, data_loader):\n",
830
+ " total_correct = 0\n",
831
+ " total_samples = 0\n",
832
+ "\n",
833
+ " with torch.no_grad():\n",
834
+ " for inputs, masks in tqdm(data_loader, desc=\"Evaluating on Entire Dataset\"):\n",
835
+ " inputs = reshape_to_2d(inputs).to(device)\n",
836
+ " masks = masks.to(device)\n",
837
+ "\n",
838
+ " outputs = model(inputs)\n",
839
+ " preds = (outputs.real > 0.5).float()\n",
840
+ "\n",
841
+ " correct = (preds == masks).float().sum()\n",
842
+ " total_correct += correct.item()\n",
843
+ " total_samples += masks.numel()\n",
844
+ "\n",
845
+ " accuracy = total_correct / total_samples * 100\n",
846
+ " print(f\"Overall Accuracy on Entire Dataset: {accuracy:.2f}%\")\n",
847
+ " return accuracy\n",
848
+ "\n",
849
+ "# Run the evaluation\n",
850
+ "overall_accuracy = evaluate_accuracy(model, entire_loader)"
851
+ ]
852
+ },
853
+ {
854
+ "cell_type": "markdown",
855
+ "id": "2a5a21b4",
856
+ "metadata": {},
857
+ "source": [
858
+ "### Function definitions"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": null,
864
+ "id": "b223d9b5",
865
+ "metadata": {},
866
+ "outputs": [],
867
+ "source": [
868
+ "import torch\n",
869
+ "from tqdm import tqdm\n",
870
+ "import numpy as np\n",
871
+ "from collections import defaultdict\n",
872
+ "import torch.nn.functional as F\n",
873
+ "from scipy.optimize import linear_sum_assignment\n",
874
+ "from torch.utils.data import ConcatDataset"
875
+ ]
876
+ },
877
+ {
878
+ "cell_type": "code",
879
+ "execution_count": null,
880
+ "id": "f54736ea",
881
+ "metadata": {},
882
+ "outputs": [],
883
+ "source": [
884
+ "# Load the pre-trained model for evaluation\n",
885
+ "device = \"cuda\"\n",
886
+ "model_path = \"path/to/model/save.pth\"\n",
887
+ "model = resnet18_1D().to(device)\n",
888
+ "model.load_state_dict(torch.load(model_path, map_location=device))\n",
889
+ "model.eval()\n"
890
+ ]
891
+ },
892
+ {
893
+ "cell_type": "code",
894
+ "execution_count": null,
895
+ "id": "dd5e7fee",
896
+ "metadata": {},
897
+ "outputs": [],
898
+ "source": [
899
+ "full_dataset = ConcatDataset([\n",
900
+ " WidebandSignalDataset(signal_ids=train, return_snrs=True),\n",
901
+ " WidebandSignalDataset(signal_ids=validation, return_snrs=True),\n",
902
+ " WidebandSignalDataset(signal_ids=test, return_snrs=True)\n",
903
+ "])"
904
+ ]
905
+ },
906
+ {
907
+ "cell_type": "code",
908
+ "execution_count": null,
909
+ "id": "173f9a8c",
910
+ "metadata": {},
911
+ "outputs": [],
912
+ "source": [
913
+ "full_loader = DataLoader(full_dataset, batch_size=64, shuffle=False)"
914
+ ]
915
+ },
916
+ {
917
+ "cell_type": "code",
918
+ "execution_count": null,
919
+ "id": "95f711d0",
920
+ "metadata": {},
921
+ "outputs": [],
922
+ "source": [
923
+ "def expand_true(array, distance=1):\n",
924
+ " # Create kernel of appropriate size\n",
925
+ " kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n",
926
+ " array = array.unsqueeze(1).float() # Add channel dimension\n",
927
+ " result = F.conv1d(array, kernel, padding=distance)\n",
928
+ " result = result.squeeze(1) # Remove the extra dimension\n",
929
+ " return result > 0\n",
930
+ "\n",
931
+ "def get_true_groups(tensor, device):\n",
932
+ " assert tensor.dim() == 2, 'This function handles 2D tensor only'\n",
933
+ " all_groups = []\n",
934
+ " for i in range(tensor.size(0)):\n",
935
+ " item = tensor[i]\n",
936
+ " item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n",
937
+ " diffs = item.float().diff()\n",
938
+ " starts = (diffs == 1).nonzero(as_tuple=True)[0]\n",
939
+ " ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n",
940
+ " groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n",
941
+ " all_groups.append(groups)\n",
942
+ " return all_groups\n",
943
+ "\n",
944
+ "def calculate_iou(box1, box2):\n",
945
+ " intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n",
946
+ " union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n",
947
+ " return intersection / union if union != 0 else 0\n",
948
+ "\n",
949
+ "def match_targets(targets, preds):\n",
950
+ " ious = []\n",
951
+ " for target in targets:\n",
952
+ " iou_targets = []\n",
953
+ " for pred in preds:\n",
954
+ " iou_targets.append(calculate_iou(target, pred))\n",
955
+ " ious.append(iou_targets)\n",
956
+ " cost_matrix = np.array(ious)\n",
957
+ " row_ind, col_ind = linear_sum_assignment(-cost_matrix)\n",
958
+ " return row_ind, col_ind\n",
959
+ "\n",
960
+ "def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n",
961
+ " ious = [0 for _ in target_boxes]\n",
962
+ " matching_dict = dict(zip(*matching))\n",
963
+ " for target_index, target_box in enumerate(target_boxes):\n",
964
+ " if target_index in matching_dict:\n",
965
+ " pred_index = matching_dict[target_index]\n",
966
+ " if pred_index < len(prediction_boxes):\n",
967
+ " box1 = target_box\n",
968
+ " box2 = prediction_boxes[pred_index]\n",
969
+ " ious[target_index] = calculate_iou(box1, box2)\n",
970
+ " return ious\n"
971
+ ]
972
+ },
973
+ {
974
+ "cell_type": "code",
975
+ "execution_count": null,
976
+ "id": "40ec3d9f",
977
+ "metadata": {},
978
+ "outputs": [],
979
+ "source": [
980
+ "def evaluate(predictor, data_loader, device=\"cuda\"):\n",
981
+ " iou_thresholds = [0.5, 0.7, 0.9]\n",
982
+ " snr_metrics = defaultdict(lambda: {\n",
983
+ " \"iou_sum\": 0.0,\n",
984
+ " \"iou_count\": 0,\n",
985
+ " \"recall_counts\": defaultdict(int),\n",
986
+ " \"total_samples\": defaultdict(int),\n",
987
+ " \"correct_pixels\": 0,\n",
988
+ " \"total_pixels\": 0\n",
989
+ " })\n",
990
+ " total_iou_sum, total_iou_count = 0.0, 0\n",
991
+ " total_correct_pixels, total_total_pixels = 0, 0\n",
992
+ " total_recall_counts = defaultdict(int)\n",
993
+ " total_samples = defaultdict(int)\n",
994
+ "\n",
995
+ " for batch in tqdm(data_loader, desc=\"Evaluating\"):\n",
996
+ " if len(batch) == 3:\n",
997
+ " inputs, masks, snrs_in_batch = batch\n",
998
+ " else:\n",
999
+ " inputs, masks = batch\n",
1000
+ " snrs_in_batch = [0] * len(inputs) # Default SNR if not provided\n",
1001
+ "\n",
1002
+ " inputs = inputs.to(device)\n",
1003
+ " masks = masks.to(device)\n",
1004
+ " outputs = predictor(inputs)\n",
1005
+ "\n",
1006
+ " for i in range(len(inputs)):\n",
1007
+ " mask = masks[i]\n",
1008
+ " output = outputs[i]\n",
1009
+ "\n",
1010
+ " # Resize output to match mask shape if necessary\n",
1011
+ " if output.numel() != mask.numel():\n",
1012
+ " output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n",
1013
+ "\n",
1014
+ " thresholded_output = (output >= 0.5).float()\n",
1015
+ "\n",
1016
+ " correct_pixels = (thresholded_output == mask).sum().item()\n",
1017
+ " total_pixels = mask.numel()\n",
1018
+ " total_correct_pixels += correct_pixels\n",
1019
+ " total_total_pixels += total_pixels\n",
1020
+ "\n",
1021
+ " # Get SNR value and round it to the nearest integer\n",
1022
+ " snr = snrs_in_batch[i]\n",
1023
+ " if isinstance(snr, torch.Tensor):\n",
1024
+ " snr = snr.item()\n",
1025
+ " snr = int(round(snr)) # Round SNR to the nearest integer\n",
1026
+ "\n",
1027
+ " snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n",
1028
+ " snr_metrics[snr][\"total_pixels\"] += total_pixels\n",
1029
+ "\n",
1030
+ " target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n",
1031
+ " pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n",
1032
+ " if not target_boxes or not pred_boxes:\n",
1033
+ " continue\n",
1034
+ " matching = match_targets(target_boxes, pred_boxes)\n",
1035
+ " matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n",
1036
+ "\n",
1037
+ " snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n",
1038
+ " snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n",
1039
+ " total_iou_sum += sum(matched_ious)\n",
1040
+ " total_iou_count += len(matched_ious)\n",
1041
+ "\n",
1042
+ " for th in iou_thresholds:\n",
1043
+ " true_positives = sum(1 for iou in matched_ious if iou >= th)\n",
1044
+ " snr_metrics[snr][\"recall_counts\"][th] += true_positives\n",
1045
+ " snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n",
1046
+ " total_recall_counts[th] += true_positives\n",
1047
+ " total_samples[th] += len(target_boxes)\n",
1048
+ "\n",
1049
+ " # Calculate overall metrics\n",
1050
+ " overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n",
1051
+ " overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n",
1052
+ " overall_recall = {\n",
1053
+ " th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0\n",
1054
+ " for th in iou_thresholds\n",
1055
+ " }\n",
1056
+ "\n",
1057
+ " # Print overall results\n",
1058
+ " print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n",
1059
+ " print(f\"Overall IoU Score: {overall_iou:.4f}\")\n",
1060
+ " for th in iou_thresholds:\n",
1061
+ " print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n",
1062
+ "\n",
1063
+ " # Print per-SNR results\n",
1064
+ " for snr in sorted(snr_metrics.keys()):\n",
1065
+ " metrics = snr_metrics[snr]\n",
1066
+ " snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
1067
+ " snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
1068
+ " print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n",
1069
+ " print(f\" IoU: {snr_iou:.4f}\")\n",
1070
+ " for th in iou_thresholds:\n",
1071
+ " recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n",
1072
+ " print(f\" Recall at threshold {th}: {recall:.4f}\")\n",
1073
+ "\n",
1074
+ " return snr_metrics\n",
1075
+ "\n",
1076
+ "\n",
1077
+ "def model_predictor(signals):\n",
1078
+ " # Use the already loaded model and apply thresholding\n",
1079
+ " return expand_true(model(signals) > 0.5)\n"
1080
+ ]
1081
+ },
1082
+ {
1083
+ "cell_type": "code",
1084
+ "execution_count": null,
1085
+ "id": "c7d3aed7",
1086
+ "metadata": {
1087
+ "scrolled": false
1088
+ },
1089
+ "outputs": [],
1090
+ "source": [
1091
+ "# Run evaluation on the full dataset\n",
1092
+ "snr_metrics = evaluate(model_predictor, full_loader, device=device)"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "markdown",
1097
+ "id": "2fd3ba0e",
1098
+ "metadata": {},
1099
+ "source": [
1100
+ "### Save and plot"
1101
+ ]
1102
+ },
1103
+ {
1104
+ "cell_type": "code",
1105
+ "execution_count": null,
1106
+ "id": "aef69113",
1107
+ "metadata": {},
1108
+ "outputs": [],
1109
+ "source": [
1110
+ "import os\n",
1111
+ "import json\n",
1112
+ "import matplotlib.pyplot as plt\n",
1113
+ "\n",
1114
+ "def save_results_and_plot(snr_metrics, save_path):\n",
1115
+ " \"\"\"\n",
1116
+ " Saves evaluation results to a JSON file and generates plots for Accuracy, IoU, and Recall vs. SNR.\n",
1117
+ " Sets x-axis limits to range from -9 dB to 12 dB to eliminate blank space on the right.\n",
1118
+ "\n",
1119
+ " Args:\n",
1120
+ " snr_metrics (dict): The evaluation results obtained from the evaluate function.\n",
1121
+ " save_path (str): The directory path where results and plots will be saved.\n",
1122
+ "\n",
1123
+ " Outputs:\n",
1124
+ " - evaluation_results.json\n",
1125
+ " - accuracy_vs_snr.png and .svg\n",
1126
+ " - iou_vs_snr.png and .svg\n",
1127
+ " - recall_vs_snr.png and .svg\n",
1128
+ " \"\"\"\n",
1129
+ " # Ensure the directory exists\n",
1130
+ " os.makedirs(save_path, exist_ok=True)\n",
1131
+ " \n",
1132
+ " # Extract data from snr_metrics\n",
1133
+ " snr_list = sorted(snr_metrics.keys())\n",
1134
+ " accuracy_list = []\n",
1135
+ " iou_list = []\n",
1136
+ " recall_05 = []\n",
1137
+ " recall_07 = []\n",
1138
+ " recall_09 = []\n",
1139
+ " \n",
1140
+ " # Prepare data for JSON serialization\n",
1141
+ " json_data = {}\n",
1142
+ " \n",
1143
+ " for snr in snr_list:\n",
1144
+ " metrics = snr_metrics[snr]\n",
1145
+ " snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
1146
+ " snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
1147
+ " recall_at_05 = metrics[\"recall_counts\"][0.5] / metrics[\"total_samples\"][0.5] if metrics[\"total_samples\"][0.5] > 0 else 0\n",
1148
+ " recall_at_07 = metrics[\"recall_counts\"][0.7] / metrics[\"total_samples\"][0.7] if metrics[\"total_samples\"][0.7] > 0 else 0\n",
1149
+ " recall_at_09 = metrics[\"recall_counts\"][0.9] / metrics[\"total_samples\"][0.9] if metrics[\"total_samples\"][0.9] > 0 else 0\n",
1150
+ "\n",
1151
+ " # Append to lists for plotting\n",
1152
+ " accuracy_list.append(snr_accuracy)\n",
1153
+ " iou_list.append(snr_iou)\n",
1154
+ " recall_05.append(recall_at_05)\n",
1155
+ " recall_07.append(recall_at_07)\n",
1156
+ " recall_09.append(recall_at_09)\n",
1157
+ "\n",
1158
+ " # Prepare data for JSON\n",
1159
+ " json_data[snr] = {\n",
1160
+ " \"accuracy\": snr_accuracy,\n",
1161
+ " \"iou\": snr_iou,\n",
1162
+ " \"recall\": {\n",
1163
+ " \"0.5\": recall_at_05,\n",
1164
+ " \"0.7\": recall_at_07,\n",
1165
+ " \"0.9\": recall_at_09,\n",
1166
+ " }\n",
1167
+ " }\n",
1168
+ " \n",
1169
+ " # Save json_data to JSON file\n",
1170
+ " json_file_path = os.path.join(save_path, 'evaluation_results.json')\n",
1171
+ " with open(json_file_path, 'w') as json_file:\n",
1172
+ " json.dump(json_data, json_file, indent=4)\n",
1173
+ " \n",
1174
+ " # Plot Accuracy vs. SNR\n",
1175
+ " plt.figure(figsize=(10, 6))\n",
1176
+ " plt.plot(snr_list, accuracy_list, marker='o', label='Accuracy')\n",
1177
+ " plt.title('Accuracy vs. SNR')\n",
1178
+ " plt.xlabel('SNR (dB)')\n",
1179
+ " plt.ylabel('Accuracy (%)')\n",
1180
+ " plt.grid(True)\n",
1181
+ " plt.legend()\n",
1182
+ " \n",
1183
+ " # Set x-axis limits\n",
1184
+ " plt.xlim(-9, 12)\n",
1185
+ " \n",
1186
+ " # Save the plot\n",
1187
+ " accuracy_png_path = os.path.join(save_path, 'accuracy_vs_snr.png')\n",
1188
+ " accuracy_svg_path = os.path.join(save_path, 'accuracy_vs_snr.svg')\n",
1189
+ " plt.savefig(accuracy_png_path, format='png', bbox_inches='tight')\n",
1190
+ " plt.savefig(accuracy_svg_path, format='svg', bbox_inches='tight')\n",
1191
+ " \n",
1192
+ " plt.show()\n",
1193
+ " plt.close()\n",
1194
+ " \n",
1195
+ " # Plot IoU vs. SNR\n",
1196
+ " plt.figure(figsize=(10, 6))\n",
1197
+ " plt.plot(snr_list, iou_list, marker='o', color='orange', label='IoU')\n",
1198
+ " plt.title('IoU vs. SNR')\n",
1199
+ " plt.xlabel('SNR (dB)')\n",
1200
+ " plt.ylabel('IoU')\n",
1201
+ " plt.grid(True)\n",
1202
+ " plt.legend()\n",
1203
+ " \n",
1204
+ " # Set x-axis limits\n",
1205
+ " plt.xlim(-9, 12)\n",
1206
+ " \n",
1207
+ " # Save the plot\n",
1208
+ " iou_png_path = os.path.join(save_path, 'iou_vs_snr.png')\n",
1209
+ " iou_svg_path = os.path.join(save_path, 'iou_vs_snr.svg')\n",
1210
+ " plt.savefig(iou_png_path, format='png', bbox_inches='tight')\n",
1211
+ " plt.savefig(iou_svg_path, format='svg', bbox_inches='tight')\n",
1212
+ " \n",
1213
+ " plt.show()\n",
1214
+ " plt.close()\n",
1215
+ " \n",
1216
+ " # Plot Recall at Different IoU Thresholds vs. SNR\n",
1217
+ " plt.figure(figsize=(10, 6))\n",
1218
+ " plt.plot(snr_list, recall_05, marker='o', label='Recall @ IoU 0.5')\n",
1219
+ " plt.plot(snr_list, recall_07, marker='s', label='Recall @ IoU 0.7')\n",
1220
+ " plt.plot(snr_list, recall_09, marker='^', label='Recall @ IoU 0.9')\n",
1221
+ " plt.title('Recall at Different IoU Thresholds vs. SNR')\n",
1222
+ " plt.xlabel('SNR (dB)')\n",
1223
+ " plt.ylabel('Recall')\n",
1224
+ " plt.grid(True)\n",
1225
+ " plt.legend()\n",
1226
+ " \n",
1227
+ " # Set x-axis limits\n",
1228
+ " plt.xlim(-9, 12)\n",
1229
+ " \n",
1230
+ " # Save the plot\n",
1231
+ " recall_png_path = os.path.join(save_path, 'recall_vs_snr.png')\n",
1232
+ " recall_svg_path = os.path.join(save_path, 'recall_vs_snr.svg')\n",
1233
+ " plt.savefig(recall_png_path, format='png', bbox_inches='tight')\n",
1234
+ " plt.savefig(recall_svg_path, format='svg', bbox_inches='tight')\n",
1235
+ " \n",
1236
+ " plt.show()\n",
1237
+ " plt.close()\n"
1238
+ ]
1239
+ },
1240
+ {
1241
+ "cell_type": "code",
1242
+ "execution_count": null,
1243
+ "id": "c9595d5e",
1244
+ "metadata": {},
1245
+ "outputs": [],
1246
+ "source": [
1247
+ "# Assuming snr_metrics is the output from the evaluate function\n",
1248
+ "# Set the save path\n",
1249
+ "save_path = 'CMuSeNet_BIGRED_results'\n",
1250
+ "\n",
1251
+ "# Call the function\n",
1252
+ "save_results_and_plot(snr_metrics, save_path)\n"
1253
+ ]
1254
+ }
1255
+ ],
1256
+ "metadata": {
1257
+ "kernelspec": {
1258
+ "display_name": "Python 3 (ipykernel)",
1259
+ "language": "python",
1260
+ "name": "python3"
1261
+ },
1262
+ "language_info": {
1263
+ "codemirror_mode": {
1264
+ "name": "ipython",
1265
+ "version": 3
1266
+ },
1267
+ "file_extension": ".py",
1268
+ "mimetype": "text/x-python",
1269
+ "name": "python",
1270
+ "nbconvert_exporter": "python",
1271
+ "pygments_lexer": "ipython3",
1272
+ "version": "3.10.9"
1273
+ }
1274
+ },
1275
+ "nbformat": 4,
1276
+ "nbformat_minor": 5
1277
+ }
CMuSeNet_Indoor_OTA.ipynb ADDED
@@ -0,0 +1,1658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "b5007b71",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Initialization"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "3e6b1226",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "### Initialization block\n",
19
+ "from pathlib import Path\n",
20
+ "import numpy as np\n",
21
+ "import json\n",
22
+ "import torch\n",
23
+ "import numpy as np\n",
24
+ "from tqdm import tqdm\n",
25
+ "import math\n",
26
+ "from torch.utils.data import DataLoader, TensorDataset\n",
27
+ "\n",
28
+ "STFT_LENGTH = 16 * 1024\n",
29
+ "DATA_DIR = Path(\"/data/OTA_reduced/\")\n",
30
+ "SAMPLE_RATE = 20e6\n",
31
+ "MODULATIONS = [\"QPSK\", \"BPSK\", \"2-FSK\"]\n",
32
+ "MODULATION_LABELS = {j: i for i, j in enumerate(MODULATIONS)}\n",
33
+ "NUMBER_OF_MODULATIONS = len(MODULATIONS)\n",
34
+ "MASK_SIZE = int(STFT_LENGTH)\n",
35
+ "\n",
36
+ "from matplotlib.mlab import psd as apply_psd\n",
37
+ "\n",
38
+ "def calc_sig_power(signal, meta, noise_power=-132.065):\n",
39
+ " \n",
40
+ " noise_floor_linear = 10 ** (noise_power / 10)\n",
41
+ " (psd, frequencies) = apply_psd(signal, Fs=SAMPLE_RATE, NFFT=1024)\n",
42
+ "\n",
43
+ "\n",
44
+ " signal_position = []\n",
45
+ "\n",
46
+ " body = meta[\"body\"]\n",
47
+ " device = meta[\"client_id\"]\n",
48
+ " bandwidth, frequency_offset = body[\"bandwidth\"] + 20e3, body[\"frequency_offset\"]\n",
49
+ "\n",
50
+ " \n",
51
+ " below_freq = frequency_offset-bandwidth/2\n",
52
+ " upper_freq = frequency_offset+bandwidth/2\n",
53
+ " sum_power_dbs = 0\n",
54
+ " freq_count = 0\n",
55
+ " \n",
56
+ " for idx, (power, freq) in enumerate(zip(psd, frequencies)):\n",
57
+ " if below_freq <= freq <= upper_freq:\n",
58
+ " freq_count+=1\n",
59
+ " sum_power_dbs+=(power)\n",
60
+ " return sum_power_dbs\n",
61
+ "\n",
62
+ "# noise_power is measured from noise signal collection\n",
63
+ "def calc_snr(signal_power, noise_power=-132.065):\n",
64
+ " noise_floor_linear = 10 ** (noise_power / 10)\n",
65
+ " snr_linear = signal_power / (noise_floor_linear * 1024)\n",
66
+ " \n",
67
+ " snr_db = 10 * np.log10(snr_linear)\n",
68
+ " \n",
69
+ " return round(snr_db)\n",
70
+ "\n",
71
+ "def convert_metadata_format_real_to_simulated(signal, metadata):\n",
72
+ " name_mapping = {\"2FSK\": \"2-FSK\"}\n",
73
+ " return [\n",
74
+ " {\n",
75
+ " \"fc\": body[\"frequency_offset\"], \n",
76
+ " \"bw\": body[\"bandwidth\"] + 20e3,\n",
77
+ " \"mod\": name_mapping.get(body[\"modulation\"], body[\"modulation\"]),\n",
78
+ " \"snr\": calc_snr(calc_sig_power(signal, meta))\n",
79
+ " } for meta in metadata if (body := meta[\"body\"])\n",
80
+ " ]\n",
81
+ "\n",
82
+ "def load_data(signal_id, load_metadata_only=False):\n",
83
+ " if not load_metadata_only:\n",
84
+ " signal_path = DATA_DIR / str(signal_id) / \"data.npy\"\n",
85
+ " if not signal_path.exists():\n",
86
+ " raise FileNotFoundError(f\"Signal file {signal_path} not found.\")\n",
87
+ " signal = np.load(signal_path)\n",
88
+ " else:\n",
89
+ " signal = None\n",
90
+ " with open(DATA_DIR / str(signal_id) / \"meta-data.json\") as f:\n",
91
+ " meta = json.load(f)\n",
92
+ " if isinstance(meta, dict):\n",
93
+ " meta = [meta]\n",
94
+ " return signal, convert_metadata_format_real_to_simulated(signal, meta)\n",
95
+ "\n",
96
+ "\n",
97
+ " \n",
98
+ "def _get_all_numbered_dirs(root_dir):\n",
99
+ " dirs = []\n",
100
+ " for directory in root_dir.iterdir():\n",
101
+ " dirs.append(int(directory.name))\n",
102
+ " dirs.sort()\n",
103
+ " return dirs\n",
104
+ " \n",
105
+ " \n",
106
+ "def process_metadata(metadata):\n",
107
+ " scaled_metadata = [\n",
108
+ " {\n",
109
+ " \"position\": (SAMPLE_RATE/2 + i['fc'], i['bw']),\n",
110
+ " \"mod\": i[\"mod\"],\n",
111
+ " \"snr\": i[\"snr\"],\n",
112
+ " \"bw\": int(i['bw'])\n",
113
+ " }\n",
114
+ " for i in metadata\n",
115
+ " ]\n",
116
+ " return scaled_metadata\n",
117
+ "\n",
118
+ "\n",
119
+ "def process_signal(signal):\n",
120
+ " signal = signal[:STFT_LENGTH]\n",
121
+ "\n",
122
+ " signal = np.fft.fft(signal)\n",
123
+ " signal = np.fft.fftshift(signal)\n",
124
+ " signal /= np.max(np.abs(signal))\n",
125
+ " return signal"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "id": "440b802c",
131
+ "metadata": {},
132
+ "source": [
133
+ "### Data Loading"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "31bc3770",
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "class WidebandSignalDataset(torch.utils.data.Dataset):\n",
144
+ " def __init__(self, signal_ids, mask_size=MASK_SIZE, return_snrs=False):\n",
145
+ " self.mask_size = mask_size\n",
146
+ " self.signal_ids = signal_ids\n",
147
+ " self.return_snrs = return_snrs\n",
148
+ " self.snrs = []\n",
149
+ " loaded_data = []\n",
150
+ " \n",
151
+ " for signal_id in tqdm(self.signal_ids):\n",
152
+ " loaded_data.append(self.process_signal(signal_id))\n",
153
+ " \n",
154
+ " self.loaded_data = loaded_data\n",
155
+ "\n",
156
+ " def __len__(self):\n",
157
+ " return len(self.signal_ids)\n",
158
+ "\n",
159
+ " def __getitem__(self, index):\n",
160
+ " if self.return_snrs:\n",
161
+ " signal, masks, snr = self.loaded_data[index]\n",
162
+ " else:\n",
163
+ " signal, masks = self.loaded_data[index]\n",
164
+ "\n",
165
+ " # Ensure `signal` is complex and `masks` is real-valued\n",
166
+ " if not isinstance(signal, torch.Tensor):\n",
167
+ " signal = torch.from_numpy(signal).type(torch.complex64)\n",
168
+ " if not isinstance(masks, torch.Tensor):\n",
169
+ " masks = torch.from_numpy(masks).type(torch.FloatTensor)\n",
170
+ "\n",
171
+ " if self.return_snrs:\n",
172
+ " if not isinstance(snr, torch.Tensor):\n",
173
+ " snr = torch.tensor(snr).type(torch.FloatTensor)\n",
174
+ " return signal, masks, snr\n",
175
+ " else:\n",
176
+ " return signal, masks\n",
177
+ "\n",
178
+ " def process_signal(self, signal_id):\n",
179
+ " # Load data and metadata\n",
180
+ " signal, metadata = load_data(signal_id)\n",
181
+ " \n",
182
+ " # Process the metadata and create masks\n",
183
+ " scaled_metadata = process_metadata(metadata)\n",
184
+ " snrs = [meta['snr'] for meta in scaled_metadata]\n",
185
+ " average_snr = sum(snrs) / len(snrs) if snrs else 0\n",
186
+ " \n",
187
+ " # Convert signal to complex format and normalize it\n",
188
+ " signal = process_signal(signal) # `process_signal` should return np.ndarray (complex)\n",
189
+ " signal = torch.from_numpy(signal).type(torch.complex64) # Convert to complex tensor\n",
190
+ " \n",
191
+ " # Generate binary mask for each frequency segment\n",
192
+ " masks = np.zeros(self.mask_size, dtype=np.float32)\n",
193
+ " scale_ratio = self.mask_size / SAMPLE_RATE\n",
194
+ " for meta in scaled_metadata:\n",
195
+ " f, b = meta['position']\n",
196
+ " x1 = math.floor((f - b / 2) * scale_ratio)\n",
197
+ " x2 = math.ceil((f + b / 2) * scale_ratio)\n",
198
+ " masks[x1:x2] = 1\n",
199
+ " \n",
200
+ " if self.return_snrs:\n",
201
+ " return signal, masks, average_snr\n",
202
+ " else:\n",
203
+ " return signal, masks\n",
204
+ "\n",
205
+ "\n",
206
+ "# Train test split 80 - 10 - 10\n",
207
+ "train, test, validation = [], [], [] \n",
208
+ "total_signals = len([i for i in DATA_DIR.iterdir()])\n",
209
+ "for index, signal in enumerate(_get_all_numbered_dirs(DATA_DIR)):\n",
210
+ " if index <= 0.80 * total_signals:\n",
211
+ " train.append(signal)\n",
212
+ " elif index <= 0.9 * total_signals:\n",
213
+ " validation.append(signal)\n",
214
+ " else:\n",
215
+ " test.append(signal)\n",
216
+ " \n",
217
+ "print(\"Train\", len(train))\n",
218
+ "print(\"Validation\", len(validation))\n",
219
+ "print(\"Test\", len(test))\n"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "id": "3e74df1a",
225
+ "metadata": {},
226
+ "source": [
227
+ "### Check if complex value"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "id": "23f75344",
234
+ "metadata": {},
235
+ "outputs": [],
236
+ "source": [
237
+ "def test_single_signal_loading(signal_id):\n",
238
+ " # Load a single signal and process it\n",
239
+ " signal, metadata = load_data(signal_id)\n",
240
+ " \n",
241
+ " # Process the signal: Apply any necessary preprocessing, and convert to complex format\n",
242
+ " processed_signal = process_signal(signal) # This should return a complex np.ndarray\n",
243
+ " complex_signal = torch.from_numpy(processed_signal).type(torch.complex64)\n",
244
+ " \n",
245
+ " # Check if the signal is complex\n",
246
+ " print(\"Loaded Signal ID:\", signal_id)\n",
247
+ " print(\"Signal Type:\", complex_signal.dtype)\n",
248
+ " print(\"Signal Shape:\", complex_signal.shape)\n",
249
+ " \n",
250
+ " # Generate the mask as you would in WidebandSignalDataset\n",
251
+ " scaled_metadata = process_metadata(metadata)\n",
252
+ " masks = np.zeros(MASK_SIZE, dtype=np.float32)\n",
253
+ " scale_ratio = MASK_SIZE / SAMPLE_RATE\n",
254
+ " for meta in scaled_metadata:\n",
255
+ " f, b = meta['position']\n",
256
+ " x1 = math.floor((f - b / 2) * scale_ratio)\n",
257
+ " x2 = math.ceil((f + b / 2) * scale_ratio)\n",
258
+ " masks[x1:x2] = 1\n",
259
+ "\n",
260
+ " # Convert mask to tensor\n",
261
+ " mask_tensor = torch.from_numpy(masks).type(torch.FloatTensor)\n",
262
+ "\n",
263
+ " # Output information about the mask\n",
264
+ " print(\"Mask Shape:\", mask_tensor.shape)\n",
265
+ " print(\"Mask Type:\", mask_tensor.dtype)\n",
266
+ " \n",
267
+ " return complex_signal, mask_tensor\n",
268
+ "\n",
269
+ "# Test with a specific signal_id (replace with an actual ID from your data)\n",
270
+ "test_signal_id = train[0] # Assuming `train` list contains valid signal IDs\n",
271
+ "complex_signal, mask_tensor = test_single_signal_loading(test_signal_id)\n",
272
+ "\n",
273
+ "# Optional: Check a sample value to confirm it's complex\n",
274
+ "print(\"Sample value from signal tensor:\", complex_signal[0])"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "1cec9c6e",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "train_dataset = WidebandSignalDataset(signal_ids=train)\n",
285
+ "validation_dataset = WidebandSignalDataset(signal_ids=validation)\n",
286
+ "test_dataset = WidebandSignalDataset(signal_ids=test)"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "id": "e0900d4e",
292
+ "metadata": {},
293
+ "source": [
294
+ "### Check SNR"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": null,
300
+ "id": "2fbee106",
301
+ "metadata": {},
302
+ "outputs": [],
303
+ "source": [
304
+ "import matplotlib.pyplot as plt\n",
305
+ "\n",
306
+ "# For Train Dataset\n",
307
+ "train_snrs = train_dataset.snrs\n",
308
+ "\n",
309
+ "# Plot Histogram of SNRs in Train Dataset\n",
310
+ "plt.figure(figsize=(10, 6))\n",
311
+ "plt.hist(train_snrs, bins=range(int(min(train_snrs)), int(max(train_snrs)) + 1), edgecolor='black')\n",
312
+ "plt.title('Histogram of SNRs in Train Dataset')\n",
313
+ "plt.xlabel('SNR (dB)')\n",
314
+ "plt.ylabel('Number of Samples')\n",
315
+ "plt.grid(True)\n",
316
+ "plt.show()\n",
317
+ "\n",
318
+ "# Print SNR Range\n",
319
+ "print('Train Dataset SNR range: {} dB to {} dB'.format(min(train_snrs), max(train_snrs)))\n",
320
+ "\n",
321
+ "# For Validation Dataset\n",
322
+ "validation_snrs = validation_dataset.snrs\n",
323
+ "\n",
324
+ "# Plot Histogram of SNRs in Validation Dataset\n",
325
+ "plt.figure(figsize=(10, 6))\n",
326
+ "plt.hist(validation_snrs, bins=range(int(min(validation_snrs)), int(max(validation_snrs)) + 1), edgecolor='black')\n",
327
+ "plt.title('Histogram of SNRs in Validation Dataset')\n",
328
+ "plt.xlabel('SNR (dB)')\n",
329
+ "plt.ylabel('Number of Samples')\n",
330
+ "plt.grid(True)\n",
331
+ "plt.show()\n",
332
+ "\n",
333
+ "# Print SNR Range\n",
334
+ "print('Validation Dataset SNR range: {} dB to {} dB'.format(min(validation_snrs), max(validation_snrs)))\n",
335
+ "\n",
336
+ "# For Test Dataset\n",
337
+ "test_snrs = test_dataset.snrs\n",
338
+ "\n",
339
+ "# Plot Histogram of SNRs in Validation Dataset\n",
340
+ "plt.figure(figsize=(10, 6))\n",
341
+ "plt.hist(test_snrs, bins=range(int(min(test_snrs)), int(max(test_snrs)) + 1), edgecolor='black')\n",
342
+ "plt.title('Histogram of SNRs in Test Dataset')\n",
343
+ "plt.xlabel('SNR (dB)')\n",
344
+ "plt.ylabel('Number of Samples')\n",
345
+ "plt.grid(True)\n",
346
+ "plt.show()\n",
347
+ "\n",
348
+ "# Print SNR Range\n",
349
+ "print('Validation Dataset SNR range: {} dB to {} dB'.format(min(test_snrs), max(test_snrs)))\n"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "markdown",
354
+ "id": "637ae774",
355
+ "metadata": {},
356
+ "source": [
357
+ "### Batch Loading"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": null,
363
+ "id": "a9af2450",
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "batch_size = 64 # Updated batch size\n",
368
+ "\n",
369
+ "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
370
+ "valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)\n",
371
+ "\n",
372
+ "print(\"Train labels shape:\", len(train_dataset))\n",
373
+ "print(\"Validation labels shape:\", len(validation_dataset))"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "markdown",
378
+ "id": "9a8e09e4",
379
+ "metadata": {},
380
+ "source": [
381
+ "### Early Stop"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": null,
387
+ "id": "24f79a24",
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "import os\n",
392
+ "\n",
393
+ "class EarlyStopping:\n",
394
+ " def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./path/to/model/save'):\n",
395
+ " self.patience = patience\n",
396
+ " self.verbose = verbose\n",
397
+ " self.delta = delta\n",
398
+ " self.counter = 0\n",
399
+ " self.best_score = None\n",
400
+ " self.early_stop = False\n",
401
+ " self.val_loss_min = float('inf')\n",
402
+ " self.best_model = None\n",
403
+ " self.save_path = save_path\n",
404
+ " os.makedirs(save_path, exist_ok=True)\n",
405
+ " \n",
406
+ " def __call__(self, val_loss, model):\n",
407
+ " score = -val_loss\n",
408
+ "\n",
409
+ " if self.best_score is None:\n",
410
+ " self.best_score = score\n",
411
+ " self.save_checkpoint(val_loss, model)\n",
412
+ " elif score < self.best_score + self.delta:\n",
413
+ " self.counter += 1\n",
414
+ " if self.verbose:\n",
415
+ " print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
416
+ " if self.counter >= self.patience:\n",
417
+ " self.early_stop = True\n",
418
+ " else:\n",
419
+ " self.best_score = score\n",
420
+ " self.save_checkpoint(val_loss, model)\n",
421
+ " self.counter = 0\n",
422
+ "\n",
423
+ " def save_checkpoint(self, val_loss, model):\n",
424
+ " if self.verbose:\n",
425
+ " print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n",
426
+ " self.val_loss_min = val_loss\n",
427
+ " self.best_model = model.state_dict()\n",
428
+ " save_path = os.path.join(self.save_path, 'best_model.pth')\n",
429
+ " torch.save(self.best_model, save_path)"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "markdown",
434
+ "id": "6c3fda74",
435
+ "metadata": {},
436
+ "source": [
437
+ "### Reshape"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "code",
442
+ "execution_count": null,
443
+ "id": "5fcf91db",
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "import torch.nn as nn\n",
448
+ "import complexPyTorch.complexLayers as cplx\n",
449
+ "import torch.nn.functional as F\n",
450
+ "import torch\n",
451
+ "\n",
452
+ "def reshape_to_2d(data):\n",
453
+ " return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "markdown",
458
+ "id": "b7d7562c",
459
+ "metadata": {},
460
+ "source": [
461
+ "### Complex IoU"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": null,
467
+ "id": "76b9d084",
468
+ "metadata": {},
469
+ "outputs": [],
470
+ "source": [
471
+ "def calculate_iou(pred, target, threshold=0.5):\n",
472
+ " real_pred = (pred.real > threshold).float()\n",
473
+ " imag_pred = (pred.imag > threshold).float()\n",
474
+ " \n",
475
+ " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
476
+ " \n",
477
+ " intersection = (combined_pred * target).sum(dim=1)\n",
478
+ " union = (combined_pred + target).sum(dim=1) - intersection\n",
479
+ " iou = (intersection / union).mean().item()\n",
480
+ " return iou"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "markdown",
485
+ "id": "64f4063c",
486
+ "metadata": {},
487
+ "source": [
488
+ "### Training"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "id": "66825110",
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": [
498
+ "import time\n",
499
+ "\n",
500
+ "def validate_model(model, valid_loader, criterion):\n",
501
+ " model.eval()\n",
502
+ " running_loss = 0.0\n",
503
+ " iou_scores = []\n",
504
+ " total_correct = 0\n",
505
+ " total_samples = 0\n",
506
+ "\n",
507
+ " with torch.no_grad():\n",
508
+ " for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n",
509
+ " inputs = reshape_to_2d(inputs).to(device)\n",
510
+ " masks = masks.to(device)\n",
511
+ " outputs = model(inputs)\n",
512
+ " loss = criterion(outputs, masks)\n",
513
+ " running_loss += loss.item()\n",
514
+ "\n",
515
+ " # Calculate IoU\n",
516
+ " iou = calculate_iou(outputs, masks, threshold=0.5)\n",
517
+ " iou_scores.append(iou)\n",
518
+ " \n",
519
+ " # Calculate accuracy\n",
520
+ " preds = (outputs.real > 0.5).float()\n",
521
+ " correct = (preds == masks).float().sum()\n",
522
+ " total_correct += correct.item()\n",
523
+ " total_samples += masks.numel()\n",
524
+ "\n",
525
+ " val_loss = running_loss / len(valid_loader)\n",
526
+ " mean_iou = sum(iou_scores) / len(iou_scores)\n",
527
+ " accuracy = total_correct / total_samples * 100\n",
528
+ "\n",
529
+ " print(f'Validation Loss: {val_loss:.6f}')\n",
530
+ " print(f'Validation Accuracy: {accuracy:.2f}%')\n",
531
+ "\n",
532
+ " return val_loss, accuracy\n",
533
+ "\n",
534
+ "def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.000001], num_epochs=50, patience=5):\n",
535
+ " train_losses = []\n",
536
+ " val_losses = []\n",
537
+ " val_accuracies = []\n",
538
+ " epoch_durations = []\n",
539
+ " \n",
540
+ " current_lr = initial_lr\n",
541
+ " for lr in lr_steps:\n",
542
+ " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
543
+ " early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n",
544
+ " print(\"Current learning rate: \", lr)\n",
545
+ " for epoch in range(num_epochs):\n",
546
+ " epoch_start_time = time.time()\n",
547
+ " \n",
548
+ " model.train()\n",
549
+ " running_loss = 0.0\n",
550
+ " for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n",
551
+ " inputs = reshape_to_2d(inputs).to(device)\n",
552
+ " masks = masks.to(device)\n",
553
+ " outputs = model(inputs)\n",
554
+ " loss = criterion(outputs, masks)\n",
555
+ "\n",
556
+ " optimizer.zero_grad()\n",
557
+ " loss.backward()\n",
558
+ " optimizer.step()\n",
559
+ "\n",
560
+ " running_loss += loss.item()\n",
561
+ "\n",
562
+ " epoch_loss = running_loss / len(train_loader)\n",
563
+ " train_losses.append(epoch_loss)\n",
564
+ " print(f\"Training Loss: {epoch_loss:.6f}\")\n",
565
+ " \n",
566
+ " val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n",
567
+ " val_losses.append(val_loss)\n",
568
+ " val_accuracies.append(val_accuracy)\n",
569
+ " early_stopping(val_loss, model)\n",
570
+ "\n",
571
+ " if early_stopping.early_stop:\n",
572
+ " print(\"Early stopping triggered\")\n",
573
+ " break\n",
574
+ "\n",
575
+ " epoch_duration = time.time() - epoch_start_time\n",
576
+ " epoch_durations.append(epoch_duration)\n",
577
+ " if early_stopping.best_model is not None:\n",
578
+ " print(f\"Loading best model from lr {lr}\")\n",
579
+ " model.load_state_dict(early_stopping.best_model)\n",
580
+ " \n",
581
+ " print(\"Training completed.\")\n",
582
+ " print(\"Epoch durations:\", epoch_durations)\n",
583
+ " return model, train_losses, val_losses, val_accuracies, epoch_durations"
584
+ ]
585
+ },
586
+ {
587
+ "cell_type": "markdown",
588
+ "id": "0b80cb51",
589
+ "metadata": {},
590
+ "source": [
591
+ "### ResNet-18"
592
+ ]
593
+ },
594
+ {
595
+ "cell_type": "code",
596
+ "execution_count": null,
597
+ "id": "2d208cb9",
598
+ "metadata": {},
599
+ "outputs": [],
600
+ "source": [
601
+ "import torch\n",
602
+ "import torch.nn as nn\n",
603
+ "import complexPyTorch.complexLayers as cplx\n",
604
+ "from typing import Optional, Callable, Type, Union, List\n",
605
+ "import torch.nn.functional as F\n",
606
+ "from torch import Tensor\n",
607
+ "\n",
608
+ "def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
609
+ " \"\"\"3x3 convolution with padding\"\"\"\n",
610
+ " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
611
+ "\n",
612
+ "def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
613
+ " \"\"\"1x1 convolution\"\"\"\n",
614
+ " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
615
+ "\n",
616
+ "class BasicBlock(nn.Module):\n",
617
+ " expansion = 1\n",
618
+ "\n",
619
+ " def __init__(\n",
620
+ " self,\n",
621
+ " inplanes: int,\n",
622
+ " planes: int,\n",
623
+ " stride: int = 1,\n",
624
+ " downsample: Optional[nn.Module] = None,\n",
625
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
626
+ " ) -> None:\n",
627
+ " super(BasicBlock, self).__init__()\n",
628
+ " self.conv1 = conv3x3(inplanes, planes, stride)\n",
629
+ " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
630
+ " self.relu = cplx.ComplexReLU()\n",
631
+ " self.conv2 = conv3x3(planes, planes)\n",
632
+ " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
633
+ " self.downsample = downsample\n",
634
+ " self.stride = stride\n",
635
+ "\n",
636
+ " def forward(self, x: Tensor) -> Tensor:\n",
637
+ " identity = x\n",
638
+ "\n",
639
+ " out = self.conv1(x)\n",
640
+ " out = self.bn1(out)\n",
641
+ " out = self.relu(out)\n",
642
+ "\n",
643
+ " out = self.conv2(out)\n",
644
+ " out = self.bn2(out)\n",
645
+ "\n",
646
+ " if self.downsample is not None:\n",
647
+ " identity = self.downsample(x)\n",
648
+ "\n",
649
+ " out += identity\n",
650
+ " out = self.relu(out)\n",
651
+ "\n",
652
+ " return out\n",
653
+ "\n",
654
+ "class Bottleneck(nn.Module):\n",
655
+ " expansion = 4\n",
656
+ "\n",
657
+ " def __init__(\n",
658
+ " self,\n",
659
+ " inplanes: int,\n",
660
+ " planes: int,\n",
661
+ " stride: int = 1,\n",
662
+ " downsample: Optional[nn.Module] = None,\n",
663
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
664
+ " ) -> None:\n",
665
+ " super(Bottleneck, self).__init__()\n",
666
+ " self.conv1 = conv1x1(inplanes, planes)\n",
667
+ " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
668
+ " self.conv2 = conv3x3(planes, planes, stride)\n",
669
+ " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
670
+ " self.conv3 = conv1x1(planes, planes * self.expansion)\n",
671
+ " self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n",
672
+ " self.relu = cplx.ComplexReLU()\n",
673
+ " self.downsample = downsample\n",
674
+ " self.stride = stride\n",
675
+ "\n",
676
+ " def forward(self, x: Tensor) -> Tensor:\n",
677
+ " identity = x\n",
678
+ "\n",
679
+ " out = self.conv1(x)\n",
680
+ " out = self.bn1(out)\n",
681
+ " out = self.relu(out)\n",
682
+ "\n",
683
+ " out = self.conv2(out)\n",
684
+ " out = self.bn2(out)\n",
685
+ " out = self.relu(out)\n",
686
+ "\n",
687
+ " out = self.conv3(out)\n",
688
+ " out = self.bn3(out)\n",
689
+ "\n",
690
+ " if self.downsample is not None:\n",
691
+ " identity = self.downsample(x)\n",
692
+ "\n",
693
+ " out += identity\n",
694
+ " out = self.relu(out)\n",
695
+ "\n",
696
+ " return out\n",
697
+ "\n",
698
+ "class ComplexResNet(nn.Module):\n",
699
+ " def __init__(\n",
700
+ " self,\n",
701
+ " block: Type[Union[BasicBlock, Bottleneck]],\n",
702
+ " layers: List[int],\n",
703
+ " num_classes: int = STFT_LENGTH,\n",
704
+ " zero_init_residual: bool = False,\n",
705
+ " groups: int = 1,\n",
706
+ " width_per_group: int = 64,\n",
707
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
708
+ " ) -> None:\n",
709
+ " super(ComplexResNet, self).__init__()\n",
710
+ " if norm_layer is None:\n",
711
+ " norm_layer = cplx.ComplexBatchNorm2d\n",
712
+ " self._norm_layer = norm_layer\n",
713
+ "\n",
714
+ " self.inplanes = 64\n",
715
+ " self.dilation = 1\n",
716
+ "\n",
717
+ " self.groups = groups\n",
718
+ " self.base_width = width_per_group\n",
719
+ " self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n",
720
+ " self.bn1 = norm_layer(self.inplanes)\n",
721
+ " self.relu = cplx.ComplexReLU()\n",
722
+ " self.layer1 = self._make_layer(block, 64, layers[0])\n",
723
+ " self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
724
+ " self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
725
+ " self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
726
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
727
+ " self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n",
728
+ " self.sigmoid = cplx.ComplexSigmoid()\n",
729
+ "\n",
730
+ " def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n",
731
+ " norm_layer = self._norm_layer\n",
732
+ " downsample = None\n",
733
+ " if stride != 1 or self.inplanes != planes * block.expansion:\n",
734
+ " downsample = nn.Sequential(\n",
735
+ " conv1x1(self.inplanes, planes * block.expansion, stride),\n",
736
+ " norm_layer(planes * block.expansion),\n",
737
+ " )\n",
738
+ "\n",
739
+ " layers = []\n",
740
+ " layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n",
741
+ " self.inplanes = planes * block.expansion\n",
742
+ " for _ in range(1, blocks):\n",
743
+ " layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n",
744
+ "\n",
745
+ " return nn.Sequential(*layers)\n",
746
+ "\n",
747
+ " def _forward_impl(self, x: Tensor) -> Tensor:\n",
748
+ " x = self.conv1(x)\n",
749
+ " x = self.bn1(x)\n",
750
+ " x = self.relu(x)\n",
751
+ "\n",
752
+ " x = self.layer1(x)\n",
753
+ " x = self.layer2(x)\n",
754
+ " x = self.layer3(x)\n",
755
+ " x = self.layer4(x)\n",
756
+ "\n",
757
+ " x = self.avgpool(x)\n",
758
+ " x = torch.flatten(x, 1)\n",
759
+ " x = self.fc(x)\n",
760
+ " x = self.sigmoid(x)\n",
761
+ " return x\n",
762
+ "\n",
763
+ " def forward(self, x: Tensor) -> Tensor:\n",
764
+ " return self._forward_impl(x)\n",
765
+ "\n",
766
+ "def ComplexResNet18():\n",
767
+ " return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n",
768
+ "\n",
769
+ "# Create the model instance\n",
770
+ "model = ComplexResNet18()\n",
771
+ "print(model)\n"
772
+ ]
773
+ },
774
+ {
775
+ "cell_type": "markdown",
776
+ "id": "e4bc1b5d",
777
+ "metadata": {},
778
+ "source": [
779
+ "### Complex focal Loss"
780
+ ]
781
+ },
782
+ {
783
+ "cell_type": "code",
784
+ "execution_count": null,
785
+ "id": "61c29429",
786
+ "metadata": {},
787
+ "outputs": [],
788
+ "source": [
789
+ "class ComplexFocalLoss(nn.Module):\n",
790
+ " def __init__(self, alpha=1, gamma=2, reduction='mean'):\n",
791
+ " super(ComplexFocalLoss, self).__init__()\n",
792
+ " self.alpha = alpha\n",
793
+ " self.gamma = gamma\n",
794
+ " self.reduction = reduction\n",
795
+ "\n",
796
+ " def forward(self, inputs, targets):\n",
797
+ " real_inputs = inputs.real\n",
798
+ " imag_inputs = inputs.imag\n",
799
+ " \n",
800
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
801
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
802
+ " \n",
803
+ " real_pt = torch.exp(-real_BCE_loss)\n",
804
+ " imag_pt = torch.exp(-imag_BCE_loss)\n",
805
+ " \n",
806
+ " real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
807
+ " imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
808
+ "\n",
809
+ " if self.reduction == 'mean':\n",
810
+ " return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
811
+ " elif self.reduction == 'sum':\n",
812
+ " return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
813
+ " else:\n",
814
+ " return real_F_loss + imag_F_loss"
815
+ ]
816
+ },
817
+ {
818
+ "cell_type": "markdown",
819
+ "id": "abb35ba2",
820
+ "metadata": {},
821
+ "source": [
822
+ "### Training with complex focal loss"
823
+ ]
824
+ },
825
+ {
826
+ "cell_type": "code",
827
+ "execution_count": null,
828
+ "id": "86d7526b",
829
+ "metadata": {
830
+ "scrolled": false
831
+ },
832
+ "outputs": [],
833
+ "source": [
834
+ "# Initialize and train the ResNet-18 model\n",
835
+ "model = ComplexResNet18().to(device)\n",
836
+ "criterion = ComplexFocalLoss()\n",
837
+ "\n",
838
+ "model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n",
839
+ "combined_epoch_time = sum(epoch_durations)\n",
840
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "markdown",
845
+ "id": "fd0c9d58",
846
+ "metadata": {},
847
+ "source": [
848
+ "### CVNN RV-BCE and CV-BCE Loss function implementation"
849
+ ]
850
+ },
851
+ {
852
+ "cell_type": "code",
853
+ "execution_count": null,
854
+ "id": "99c736b8",
855
+ "metadata": {},
856
+ "outputs": [],
857
+ "source": [
858
+ "# CV BCE Loss Function Definition\n",
859
+ "class ComplexValuedBCELoss(nn.Module):\n",
860
+ " def __init__(self, reduction='mean'):\n",
861
+ " super(ComplexValuedBCELoss, self).__init__()\n",
862
+ " self.reduction = reduction\n",
863
+ "\n",
864
+ " def forward(self, inputs, targets):\n",
865
+ " real_inputs = inputs.real\n",
866
+ " imag_inputs = inputs.imag\n",
867
+ "\n",
868
+ " # Calculate binary cross-entropy for both real and imaginary parts\n",
869
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
870
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n",
871
+ " \n",
872
+ " # Combine the losses (you can adjust the weighting if necessary)\n",
873
+ " combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n",
874
+ " return combined_BCE_loss"
875
+ ]
876
+ },
877
+ {
878
+ "cell_type": "markdown",
879
+ "id": "93d19ea7",
880
+ "metadata": {},
881
+ "source": [
882
+ "### CV-BCE Training"
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": null,
888
+ "id": "2c56d5b4",
889
+ "metadata": {
890
+ "scrolled": false
891
+ },
892
+ "outputs": [],
893
+ "source": [
894
+ "# Set the criterion for CV BCE\n",
895
+ "criterion = ComplexValuedBCELoss()\n",
896
+ "\n",
897
+ "# Train the ResNet-18 model with CV BCE\n",
898
+ "device = torch.device('cuda')\n",
899
+ "model = ComplexResNet18().to(device)\n",
900
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
901
+ "\n",
902
+ "# Start training with the previously defined train_model function\n",
903
+ "model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n",
904
+ " model, train_loader, valid_loader, criterion, \n",
905
+ " initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
906
+ ")\n",
907
+ "combined_epoch_time = sum(epoch_durations)\n",
908
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
909
+ ]
910
+ },
911
+ {
912
+ "cell_type": "markdown",
913
+ "id": "7ccd50ff",
914
+ "metadata": {},
915
+ "source": [
916
+ "### Save and Plot"
917
+ ]
918
+ },
919
+ {
920
+ "cell_type": "code",
921
+ "execution_count": null,
922
+ "id": "eb41b92f",
923
+ "metadata": {},
924
+ "outputs": [],
925
+ "source": [
926
+ "import os\n",
927
+ "import json\n",
928
+ "import matplotlib.pyplot as plt\n",
929
+ "\n",
930
+ "# Define save directory\n",
931
+ "save_dir = 'CMuSeNet_results/segmentation_OTA'\n",
932
+ "\n",
933
+ "# Create the directory if it doesn't exist\n",
934
+ "os.makedirs(save_dir, exist_ok=True)\n",
935
+ "\n",
936
+ "# Plot training loss\n",
937
+ "plt.figure()\n",
938
+ "plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', color='blue')\n",
939
+ "plt.title('Training Loss')\n",
940
+ "plt.xlabel('Epoch')\n",
941
+ "plt.ylabel('Loss')\n",
942
+ "plt.legend()\n",
943
+ "\n",
944
+ "# Save the training loss figure as PNG and SVG\n",
945
+ "plt.savefig(os.path.join(save_dir, 'training_loss.png'))\n",
946
+ "plt.savefig(os.path.join(save_dir, 'training_loss.svg'))\n",
947
+ "\n",
948
+ "# Show the training loss plot\n",
949
+ "plt.show()\n",
950
+ "\n",
951
+ "# Plot validation accuracy\n",
952
+ "plt.figure()\n",
953
+ "plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', color='green')\n",
954
+ "plt.title('Validation Accuracy')\n",
955
+ "plt.xlabel('Epoch')\n",
956
+ "plt.ylabel('Accuracy')\n",
957
+ "plt.legend()\n",
958
+ "\n",
959
+ "# Save the validation accuracy figure as PNG and SVG\n",
960
+ "plt.savefig(os.path.join(save_dir, 'validation_accuracy.png'))\n",
961
+ "plt.savefig(os.path.join(save_dir, 'validation_accuracy.svg'))\n",
962
+ "\n",
963
+ "# Show the validation accuracy plot\n",
964
+ "plt.show()\n",
965
+ "\n",
966
+ "# Save the actual data to a JSON file\n",
967
+ "results = {\n",
968
+ " \"train_losses\": train_losses,\n",
969
+ " \"val_accuracies\": val_accuracies,\n",
970
+ " \"epoch_durations\": epoch_durations,\n",
971
+ " \"combined_epoch_time\": combined_epoch_time\n",
972
+ "}\n",
973
+ "\n",
974
+ "# Save JSON file\n",
975
+ "with open(os.path.join(save_dir, 'training_validation_results.json'), 'w') as f:\n",
976
+ " json.dump(results, f)"
977
+ ]
978
+ },
979
+ {
980
+ "cell_type": "markdown",
981
+ "id": "3a757949",
982
+ "metadata": {},
983
+ "source": [
984
+ "### Transfer Learning from Synthetic model"
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "markdown",
989
+ "id": "ee265d28",
990
+ "metadata": {},
991
+ "source": [
992
+ "### Load pre-trained model"
993
+ ]
994
+ },
995
+ {
996
+ "cell_type": "code",
997
+ "execution_count": null,
998
+ "id": "0dec6746",
999
+ "metadata": {},
1000
+ "outputs": [],
1001
+ "source": [
1002
+ "# Block to load pre-trained model and prepare for transfer learning\n",
1003
+ "device = torch.device(\"cuda\")\n",
1004
+ "\n",
1005
+ "# Load the pre-trained model\n",
1006
+ "\n",
1007
+ "model_path = \"path/to/model/save.pth\"\n",
1008
+ "model = ComplexResNet18().to(device)\n",
1009
+ "model.load_state_dict(torch.load(model_path, map_location=device))\n",
1010
+ "\n",
1011
+ "# Freeze all layers except the final layer\n",
1012
+ "for param in model.parameters():\n",
1013
+ " param.requires_grad = False\n",
1014
+ "\n",
1015
+ "# Modify the final layer for transfer learning (adjust `num_classes` as needed)\n",
1016
+ "num_classes = STFT_LENGTH # Set based on your current task\n",
1017
+ "model.fc = cplx.ComplexLinear(512 * BasicBlock.expansion, num_classes).to(device)\n",
1018
+ "\n",
1019
+ "# Unfreeze the final layer for training\n",
1020
+ "for param in model.fc.parameters():\n",
1021
+ " param.requires_grad = True\n"
1022
+ ]
1023
+ },
1024
+ {
1025
+ "cell_type": "markdown",
1026
+ "id": "21e1e62b",
1027
+ "metadata": {},
1028
+ "source": [
1029
+ "### Complex Learning for Transfer Learning (Same as above but easier access)"
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "id": "4c6656d0",
1036
+ "metadata": {},
1037
+ "outputs": [],
1038
+ "source": [
1039
+ "class ComplexFocalLoss(nn.Module):\n",
1040
+ " def __init__(self, alpha=0.5, gamma=2, reduction='mean'):\n",
1041
+ " super(ComplexFocalLoss, self).__init__()\n",
1042
+ " self.alpha = alpha\n",
1043
+ " self.gamma = gamma\n",
1044
+ " self.reduction = reduction\n",
1045
+ "\n",
1046
+ " def forward(self, inputs, targets):\n",
1047
+ " real_inputs = inputs.real\n",
1048
+ " imag_inputs = inputs.imag\n",
1049
+ " \n",
1050
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
1051
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
1052
+ " \n",
1053
+ " real_pt = torch.exp(-real_BCE_loss)\n",
1054
+ " imag_pt = torch.exp(-imag_BCE_loss)\n",
1055
+ " \n",
1056
+ " real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
1057
+ " imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
1058
+ "\n",
1059
+ " if self.reduction == 'mean':\n",
1060
+ " return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
1061
+ " elif self.reduction == 'sum':\n",
1062
+ " return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
1063
+ " else:\n",
1064
+ " return real_F_loss + imag_F_loss\n",
1065
+ "\n",
1066
+ "# Update the IoU calculation to handle complex values\n",
1067
+ "def calculate_iou(pred, target, threshold=0.5):\n",
1068
+ " real_pred = (pred.real > threshold).float()\n",
1069
+ " imag_pred = (pred.imag > threshold).float()\n",
1070
+ " \n",
1071
+ " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
1072
+ " \n",
1073
+ " intersection = (combined_pred * target).sum(dim=1)\n",
1074
+ " union = (combined_pred + target).sum(dim=1) - intersection\n",
1075
+ " iou = (intersection / union).mean().item()\n",
1076
+ " return iou"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "cell_type": "markdown",
1081
+ "id": "bc9b7701",
1082
+ "metadata": {},
1083
+ "source": [
1084
+ "### Transfer Learning"
1085
+ ]
1086
+ },
1087
+ {
1088
+ "cell_type": "code",
1089
+ "execution_count": null,
1090
+ "id": "c291a42e",
1091
+ "metadata": {
1092
+ "scrolled": false
1093
+ },
1094
+ "outputs": [],
1095
+ "source": [
1096
+ "# Define a new criterion and optimizer for fine-tuning\n",
1097
+ "# You may select between Focal Loss or BCE as your criterion\n",
1098
+ "#criterion = ComplexValuedBCELoss() # or ComplexValuedBCELoss()\n",
1099
+ "criterion = ComplexFocalLoss()\n",
1100
+ "# Use a smaller learning rate for fine-tuning\n",
1101
+ "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n",
1102
+ "\n",
1103
+ "# Train the model (fine-tuning)\n",
1104
+ "model, train_losses, val_losses, val_accuracies, epoch_durations= train_model(\n",
1105
+ " model, train_loader, valid_loader, criterion,\n",
1106
+ " initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
1107
+ ")\n",
1108
+ "combined_epoch_time = sum(epoch_durations)\n",
1109
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
1110
+ ]
1111
+ },
1112
+ {
1113
+ "cell_type": "markdown",
1114
+ "id": "98f81acc",
1115
+ "metadata": {},
1116
+ "source": [
1117
+ "## Transfer Transfer Learning (Different Radio)"
1118
+ ]
1119
+ },
1120
+ {
1121
+ "cell_type": "code",
1122
+ "execution_count": null,
1123
+ "id": "55017794",
1124
+ "metadata": {},
1125
+ "outputs": [],
1126
+ "source": [
1127
+ "# Block to load pre-trained model and prepare for transfer learning\n",
1128
+ "device = torch.device(\"cuda\")\n",
1129
+ "\n",
1130
+ "model_path = \"/path/to/model/save.pth\"\n",
1131
+ "model = ComplexResNet18().to(device)\n",
1132
+ "#model = ComplexValuedBCELoss().to(device)\n",
1133
+ "model.load_state_dict(torch.load(model_path, map_location=device))\n",
1134
+ "\n",
1135
+ "# Freeze all layers except the final layer\n",
1136
+ "for param in model.parameters():\n",
1137
+ " param.requires_grad = False\n",
1138
+ "\n",
1139
+ "# Modify the final layer for transfer learning (adjust `num_classes` as needed)\n",
1140
+ "num_classes = STFT_LENGTH # Set based on your current task\n",
1141
+ "model.fc = cplx.ComplexLinear(512 * BasicBlock.expansion, num_classes).to(device)\n",
1142
+ "\n",
1143
+ "# Unfreeze the final layer for training\n",
1144
+ "for param in model.fc.parameters():\n",
1145
+ " param.requires_grad = True\n"
1146
+ ]
1147
+ },
1148
+ {
1149
+ "cell_type": "code",
1150
+ "execution_count": null,
1151
+ "id": "5933b01f",
1152
+ "metadata": {},
1153
+ "outputs": [],
1154
+ "source": [
1155
+ "# Define a new criterion and optimizer for fine-tuning\n",
1156
+ "# You may select between Focal Loss or BCE as your criterion\n",
1157
+ "#criterion = ComplexValuedBCELoss() # or ComplexValuedBCELoss()\n",
1158
+ "criterion = ComplexFocalLoss()\n",
1159
+ "# Use a smaller learning rate for fine-tuning\n",
1160
+ "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n",
1161
+ "\n",
1162
+ "# Train the model (fine-tuning)\n",
1163
+ "model, train_losses, val_losses, val_accuracies, epoch_durations= train_model(\n",
1164
+ " model, train_loader, valid_loader, criterion,\n",
1165
+ " initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=1\n",
1166
+ ")\n",
1167
+ "combined_epoch_time = sum(epoch_durations)\n",
1168
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
1169
+ ]
1170
+ },
1171
+ {
1172
+ "cell_type": "markdown",
1173
+ "id": "dbef8bad",
1174
+ "metadata": {},
1175
+ "source": [
1176
+ "### Evaluation CVNN OTA"
1177
+ ]
1178
+ },
1179
+ {
1180
+ "cell_type": "code",
1181
+ "execution_count": null,
1182
+ "id": "d0ac03b7",
1183
+ "metadata": {},
1184
+ "outputs": [],
1185
+ "source": [
1186
+ "import torch\n",
1187
+ "from tqdm import tqdm\n",
1188
+ "import numpy as np\n",
1189
+ "from collections import defaultdict\n",
1190
+ "import torch.nn.functional as F\n",
1191
+ "from scipy.optimize import linear_sum_assignment\n",
1192
+ "from torch.utils.data import ConcatDataset"
1193
+ ]
1194
+ },
1195
+ {
1196
+ "cell_type": "code",
1197
+ "execution_count": null,
1198
+ "id": "f831e874",
1199
+ "metadata": {},
1200
+ "outputs": [],
1201
+ "source": [
1202
+ "device = \"cuda\"\n",
1203
+ "\n",
1204
+ "model_path = \"/path/to/model/save.pth\"\n",
1205
+ "model = ComplexResNet18().to(device)\n",
1206
+ "model.load_state_dict(torch.load(model_path, map_location=device))\n",
1207
+ "model.eval()"
1208
+ ]
1209
+ },
1210
+ {
1211
+ "cell_type": "code",
1212
+ "execution_count": null,
1213
+ "id": "a303080e",
1214
+ "metadata": {},
1215
+ "outputs": [],
1216
+ "source": [
1217
+ "# Load the pre-trained model for evaluation\n",
1218
+ "\n",
1219
+ "full_dataset = ConcatDataset([\n",
1220
+ " WidebandSignalDataset(signal_ids=train, return_snrs=True),\n",
1221
+ " WidebandSignalDataset(signal_ids=validation, return_snrs=True),\n",
1222
+ " WidebandSignalDataset(signal_ids=test, return_snrs=True)\n",
1223
+ "])\n",
1224
+ "full_loader = DataLoader(full_dataset, batch_size=64, shuffle=False)"
1225
+ ]
1226
+ },
1227
+ {
1228
+ "cell_type": "markdown",
1229
+ "id": "ad326f1d",
1230
+ "metadata": {},
1231
+ "source": [
1232
+ "### Function initialization"
1233
+ ]
1234
+ },
1235
+ {
1236
+ "cell_type": "code",
1237
+ "execution_count": null,
1238
+ "id": "00d0228c",
1239
+ "metadata": {},
1240
+ "outputs": [],
1241
+ "source": [
1242
+ "def expand_true(array, distance=1):\n",
1243
+ " # Create kernel of appropriate size\n",
1244
+ " kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n",
1245
+ " array = array.unsqueeze(1).float() # Add channel dimension\n",
1246
+ " result = F.conv1d(array, kernel, padding=distance)\n",
1247
+ " result = result.squeeze(1) # Remove the extra dimension\n",
1248
+ " return result > 0\n",
1249
+ "def reshape_to_2d(data):\n",
1250
+ " return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]\n",
1251
+ "def get_true_groups(tensor, device):\n",
1252
+ " assert tensor.dim() == 2, 'This function handles 2D tensor only'\n",
1253
+ " all_groups = []\n",
1254
+ " for i in range(tensor.size(0)):\n",
1255
+ " item = tensor[i]\n",
1256
+ " item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n",
1257
+ " diffs = item.float().diff()\n",
1258
+ " starts = (diffs == 1).nonzero(as_tuple=True)[0]\n",
1259
+ " ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n",
1260
+ " groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n",
1261
+ " all_groups.append(groups)\n",
1262
+ " return all_groups\n",
1263
+ "\n",
1264
+ "def calculate_iou(box1, box2):\n",
1265
+ " intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n",
1266
+ " union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n",
1267
+ " return intersection / union if union != 0 else 0\n",
1268
+ "\n",
1269
+ "def match_targets(targets, preds):\n",
1270
+ " ious = []\n",
1271
+ " for target in targets:\n",
1272
+ " iou_targets = []\n",
1273
+ " for pred in preds:\n",
1274
+ " iou_targets.append(calculate_iou(target, pred))\n",
1275
+ " ious.append(iou_targets)\n",
1276
+ " cost_matrix = np.array(ious)\n",
1277
+ " row_ind, col_ind = linear_sum_assignment(-cost_matrix)\n",
1278
+ " return row_ind, col_ind\n",
1279
+ "\n",
1280
+ "def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n",
1281
+ " ious = [0 for _ in target_boxes]\n",
1282
+ " matching_dict = dict(zip(*matching))\n",
1283
+ " for target_index, target_box in enumerate(target_boxes):\n",
1284
+ " if target_index in matching_dict:\n",
1285
+ " pred_index = matching_dict[target_index]\n",
1286
+ " if pred_index < len(prediction_boxes):\n",
1287
+ " box1 = target_box\n",
1288
+ " box2 = prediction_boxes[pred_index]\n",
1289
+ " ious[target_index] = calculate_iou(box1, box2)\n",
1290
+ " return ious\n",
1291
+ "def model_predictor(signals):\n",
1292
+ " # Convert signals to complex tensors\n",
1293
+ " if signals.dtype != torch.complex64 and signals.dtype != torch.complex128:\n",
1294
+ " signals = signals.type(torch.complex64)\n",
1295
+ " # Reshape the input signals to the expected shape\n",
1296
+ " signals = reshape_to_2d(signals)\n",
1297
+ " signals = signals.to(device)\n",
1298
+ " # Use the already loaded model and apply thresholding\n",
1299
+ " with torch.no_grad():\n",
1300
+ " outputs = model(signals)\n",
1301
+ " # Handle complex outputs appropriately\n",
1302
+ " real_outputs = outputs.real\n",
1303
+ " imag_outputs = outputs.imag\n",
1304
+ " real_pred = (real_outputs > 0.5)\n",
1305
+ " imag_pred = (imag_outputs > 0.5)\n",
1306
+ " combined_pred = torch.logical_or(real_pred, imag_pred)\n",
1307
+ " return expand_true(combined_pred.float())\n",
1308
+ "\n",
1309
+ "# Complex IoU Implementation\n",
1310
+ "def calculate_complex_iou(box1_real, box1_imag, box2_real, box2_imag):\n",
1311
+ " # Calculate real component intersection\n",
1312
+ " real_intersection = max(0, min(box1_real[1], box2_real[1]) - max(box1_real[0], box2_real[0]))\n",
1313
+ " real_union = max(box1_real[1], box2_real[1]) - min(box1_real[0], box2_real[0])\n",
1314
+ " \n",
1315
+ " # Calculate imaginary component intersection\n",
1316
+ " imag_intersection = max(0, min(box1_imag[1], box2_imag[1]) - max(box1_imag[0], box2_imag[0]))\n",
1317
+ " imag_union = max(box1_imag[1], box2_imag[1]) - min(box1_imag[0], box2_imag[0])\n",
1318
+ " \n",
1319
+ " # Combine intersections and unions\n",
1320
+ " total_intersection = real_intersection + imag_intersection\n",
1321
+ " total_union = real_union + imag_union\n",
1322
+ " \n",
1323
+ " # Return IoU\n",
1324
+ " return total_intersection / total_union if total_union != 0 else 0\n",
1325
+ "\n",
1326
+ "def match_complex_targets(targets_real, targets_imag, preds_real, preds_imag):\n",
1327
+ " ious = []\n",
1328
+ " for target_real, target_imag in zip(targets_real, targets_imag):\n",
1329
+ " iou_targets = []\n",
1330
+ " for pred_real, pred_imag in zip(preds_real, preds_imag):\n",
1331
+ " iou_targets.append(calculate_complex_iou(target_real, target_imag, pred_real, pred_imag))\n",
1332
+ " ious.append(iou_targets)\n",
1333
+ " cost_matrix = np.array(ious)\n",
1334
+ " row_ind, col_ind = linear_sum_assignment(-cost_matrix)\n",
1335
+ " return row_ind, col_ind\n",
1336
+ "\n",
1337
+ "def calculate_matched_complex_ious(target_boxes_real, target_boxes_imag, \n",
1338
+ " prediction_boxes_real, prediction_boxes_imag, matching):\n",
1339
+ " ious = [0 for _ in target_boxes_real]\n",
1340
+ " matching_dict = dict(zip(*matching))\n",
1341
+ " for target_index, (target_box_real, target_box_imag) in enumerate(zip(target_boxes_real, target_boxes_imag)):\n",
1342
+ " if target_index in matching_dict:\n",
1343
+ " pred_index = matching_dict[target_index]\n",
1344
+ " if pred_index < len(prediction_boxes_real):\n",
1345
+ " box1_real, box1_imag = target_box_real, target_box_imag\n",
1346
+ " box2_real, box2_imag = prediction_boxes_real[pred_index], prediction_boxes_imag[pred_index]\n",
1347
+ " ious[target_index] = calculate_complex_iou(box1_real, box1_imag, box2_real, box2_imag)\n",
1348
+ " return ious\n"
1349
+ ]
1350
+ },
1351
+ {
1352
+ "cell_type": "markdown",
1353
+ "id": "c114c7a2",
1354
+ "metadata": {},
1355
+ "source": [
1356
+ "### Evaluate function"
1357
+ ]
1358
+ },
1359
+ {
1360
+ "cell_type": "code",
1361
+ "execution_count": null,
1362
+ "id": "41f12e83",
1363
+ "metadata": {},
1364
+ "outputs": [],
1365
+ "source": [
1366
+ "def evaluate(predictor, data_loader, device=\"cuda\"):\n",
1367
+ " iou_thresholds = [0.5, 0.7, 0.9]\n",
1368
+ " snr_metrics = defaultdict(lambda: {\n",
1369
+ " \"iou_sum\": 0.0,\n",
1370
+ " \"iou_count\": 0,\n",
1371
+ " \"recall_counts\": defaultdict(int),\n",
1372
+ " \"total_samples\": defaultdict(int),\n",
1373
+ " \"correct_pixels\": 0,\n",
1374
+ " \"total_pixels\": 0\n",
1375
+ " })\n",
1376
+ " total_iou_sum, total_iou_count = 0.0, 0\n",
1377
+ " total_correct_pixels, total_total_pixels = 0, 0\n",
1378
+ " total_recall_counts = defaultdict(int)\n",
1379
+ " total_samples = defaultdict(int)\n",
1380
+ "\n",
1381
+ " for batch in tqdm(data_loader, desc=\"Evaluating\"):\n",
1382
+ " if len(batch) == 3:\n",
1383
+ " inputs, masks, snrs_in_batch = batch\n",
1384
+ " else:\n",
1385
+ " inputs, masks = batch\n",
1386
+ " snrs_in_batch = [0] * len(inputs) # Default SNR if not provided\n",
1387
+ "\n",
1388
+ " inputs = inputs.to(device)\n",
1389
+ " masks = masks.to(device)\n",
1390
+ " outputs = predictor(inputs)\n",
1391
+ "\n",
1392
+ " for i in range(len(inputs)):\n",
1393
+ " mask = masks[i]\n",
1394
+ " output = outputs[i]\n",
1395
+ "\n",
1396
+ " # Resize output to match mask shape if necessary\n",
1397
+ " if output.numel() != mask.numel():\n",
1398
+ " output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n",
1399
+ "\n",
1400
+ " thresholded_output = (output >= 0.5).float()\n",
1401
+ "\n",
1402
+ " correct_pixels = (thresholded_output == mask).sum().item()\n",
1403
+ " total_pixels = mask.numel()\n",
1404
+ " total_correct_pixels += correct_pixels\n",
1405
+ " total_total_pixels += total_pixels\n",
1406
+ "\n",
1407
+ " # Get SNR value and round it to the nearest integer\n",
1408
+ " snr = snrs_in_batch[i]\n",
1409
+ " if isinstance(snr, torch.Tensor):\n",
1410
+ " snr = snr.item()\n",
1411
+ " snr = int(round(snr)) # Round SNR to the nearest integer\n",
1412
+ "\n",
1413
+ " snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n",
1414
+ " snr_metrics[snr][\"total_pixels\"] += total_pixels\n",
1415
+ "\n",
1416
+ " target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n",
1417
+ " pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n",
1418
+ " if not target_boxes or not pred_boxes:\n",
1419
+ " continue\n",
1420
+ " matching = match_targets(target_boxes, pred_boxes)\n",
1421
+ " matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n",
1422
+ "\n",
1423
+ " snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n",
1424
+ " snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n",
1425
+ " total_iou_sum += sum(matched_ious)\n",
1426
+ " total_iou_count += len(matched_ious)\n",
1427
+ "\n",
1428
+ " for th in iou_thresholds:\n",
1429
+ " true_positives = sum(1 for iou in matched_ious if iou >= th)\n",
1430
+ " snr_metrics[snr][\"recall_counts\"][th] += true_positives\n",
1431
+ " snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n",
1432
+ " total_recall_counts[th] += true_positives\n",
1433
+ " total_samples[th] += len(target_boxes)\n",
1434
+ "\n",
1435
+ " # Calculate overall metrics\n",
1436
+ " overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n",
1437
+ " overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n",
1438
+ " overall_recall = {\n",
1439
+ " th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0\n",
1440
+ " for th in iou_thresholds\n",
1441
+ " }\n",
1442
+ "\n",
1443
+ " # Print overall results\n",
1444
+ " print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n",
1445
+ " print(f\"Overall IoU Score: {overall_iou:.4f}\")\n",
1446
+ " for th in iou_thresholds:\n",
1447
+ " print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n",
1448
+ "\n",
1449
+ " # Print per-SNR results\n",
1450
+ " for snr in sorted(snr_metrics.keys()):\n",
1451
+ " metrics = snr_metrics[snr]\n",
1452
+ " snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
1453
+ " snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
1454
+ " print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n",
1455
+ " print(f\" IoU: {snr_iou:.4f}\")\n",
1456
+ " for th in iou_thresholds:\n",
1457
+ " recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n",
1458
+ " print(f\" Recall at threshold {th}: {recall:.4f}\")\n",
1459
+ "\n",
1460
+ " return snr_metrics\n"
1461
+ ]
1462
+ },
1463
+ {
1464
+ "cell_type": "code",
1465
+ "execution_count": null,
1466
+ "id": "0d2fd13f",
1467
+ "metadata": {
1468
+ "scrolled": false
1469
+ },
1470
+ "outputs": [],
1471
+ "source": [
1472
+ "# Run evaluation on the full dataset\n",
1473
+ "snr_metrics = evaluate(model_predictor, full_loader, device=device)"
1474
+ ]
1475
+ },
1476
+ {
1477
+ "cell_type": "markdown",
1478
+ "id": "07eade04",
1479
+ "metadata": {},
1480
+ "source": [
1481
+ "### Save and Plot"
1482
+ ]
1483
+ },
1484
+ {
1485
+ "cell_type": "code",
1486
+ "execution_count": null,
1487
+ "id": "bc84b73a",
1488
+ "metadata": {},
1489
+ "outputs": [],
1490
+ "source": [
1491
+ "import os\n",
1492
+ "import json\n",
1493
+ "import matplotlib.pyplot as plt\n",
1494
+ "\n",
1495
+ "def save_results_and_plot(snr_metrics, save_path):\n",
1496
+ " \"\"\"\n",
1497
+ " Saves evaluation results to a JSON file and generates plots for Accuracy, IoU, and Recall vs. SNR.\n",
1498
+ " Sets x-axis limits to range from -9 dB to 12 dB to eliminate blank space on the right.\n",
1499
+ "\n",
1500
+ " Args:\n",
1501
+ " snr_metrics (dict): The evaluation results obtained from the evaluate function.\n",
1502
+ " save_path (str): The directory path where results and plots will be saved.\n",
1503
+ "\n",
1504
+ " Outputs:\n",
1505
+ " - evaluation_results.json\n",
1506
+ " - accuracy_vs_snr.png and .svg\n",
1507
+ " - iou_vs_snr.png and .svg\n",
1508
+ " - recall_vs_snr.png and .svg\n",
1509
+ " \"\"\"\n",
1510
+ " # Ensure the directory exists\n",
1511
+ " os.makedirs(save_path, exist_ok=True)\n",
1512
+ " \n",
1513
+ " # Extract data from snr_metrics\n",
1514
+ " snr_list = sorted(snr_metrics.keys())\n",
1515
+ " accuracy_list = []\n",
1516
+ " iou_list = []\n",
1517
+ " recall_05 = []\n",
1518
+ " recall_07 = []\n",
1519
+ " recall_09 = []\n",
1520
+ " \n",
1521
+ " # Prepare data for JSON serialization\n",
1522
+ " json_data = {}\n",
1523
+ " \n",
1524
+ " for snr in snr_list:\n",
1525
+ " metrics = snr_metrics[snr]\n",
1526
+ " snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
1527
+ " snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
1528
+ " recall_at_05 = metrics[\"recall_counts\"][0.5] / metrics[\"total_samples\"][0.5] if metrics[\"total_samples\"][0.5] > 0 else 0\n",
1529
+ " recall_at_07 = metrics[\"recall_counts\"][0.7] / metrics[\"total_samples\"][0.7] if metrics[\"total_samples\"][0.7] > 0 else 0\n",
1530
+ " recall_at_09 = metrics[\"recall_counts\"][0.9] / metrics[\"total_samples\"][0.9] if metrics[\"total_samples\"][0.9] > 0 else 0\n",
1531
+ "\n",
1532
+ " # Append to lists for plotting\n",
1533
+ " accuracy_list.append(snr_accuracy)\n",
1534
+ " iou_list.append(snr_iou)\n",
1535
+ " recall_05.append(recall_at_05)\n",
1536
+ " recall_07.append(recall_at_07)\n",
1537
+ " recall_09.append(recall_at_09)\n",
1538
+ "\n",
1539
+ " # Prepare data for JSON\n",
1540
+ " json_data[snr] = {\n",
1541
+ " \"accuracy\": snr_accuracy,\n",
1542
+ " \"iou\": snr_iou,\n",
1543
+ " \"recall\": {\n",
1544
+ " \"0.5\": recall_at_05,\n",
1545
+ " \"0.7\": recall_at_07,\n",
1546
+ " \"0.9\": recall_at_09,\n",
1547
+ " }\n",
1548
+ " }\n",
1549
+ " \n",
1550
+ " # Save json_data to JSON file\n",
1551
+ " json_file_path = os.path.join(save_path, 'evaluation_results.json')\n",
1552
+ " with open(json_file_path, 'w') as json_file:\n",
1553
+ " json.dump(json_data, json_file, indent=4)\n",
1554
+ " \n",
1555
+ " # Plot Accuracy vs. SNR\n",
1556
+ " plt.figure(figsize=(10, 6))\n",
1557
+ " plt.plot(snr_list, accuracy_list, marker='o', label='Accuracy')\n",
1558
+ " plt.title('Accuracy vs. SNR')\n",
1559
+ " plt.xlabel('SNR (dB)')\n",
1560
+ " plt.ylabel('Accuracy (%)')\n",
1561
+ " plt.grid(True)\n",
1562
+ " plt.legend()\n",
1563
+ " \n",
1564
+ " # Set x-axis limits\n",
1565
+ " #plt.xlim(-9, 12)\n",
1566
+ " plt.xlim(-16, 16)\n",
1567
+ " # Save the plot\n",
1568
+ " accuracy_png_path = os.path.join(save_path, 'accuracy_vs_snr.png')\n",
1569
+ " accuracy_svg_path = os.path.join(save_path, 'accuracy_vs_snr.svg')\n",
1570
+ " plt.savefig(accuracy_png_path, format='png', bbox_inches='tight')\n",
1571
+ " plt.savefig(accuracy_svg_path, format='svg', bbox_inches='tight')\n",
1572
+ " \n",
1573
+ " plt.show()\n",
1574
+ " plt.close()\n",
1575
+ " \n",
1576
+ " # Plot IoU vs. SNR\n",
1577
+ " plt.figure(figsize=(10, 6))\n",
1578
+ " plt.plot(snr_list, iou_list, marker='o', color='orange', label='IoU')\n",
1579
+ " plt.title('IoU vs. SNR')\n",
1580
+ " plt.xlabel('SNR (dB)')\n",
1581
+ " plt.ylabel('IoU')\n",
1582
+ " plt.grid(True)\n",
1583
+ " plt.legend()\n",
1584
+ " \n",
1585
+ " # Set x-axis limits\n",
1586
+ " #plt.xlim(-9, 12)\n",
1587
+ " plt.xlim(-16, 16)\n",
1588
+ " # Save the plot\n",
1589
+ " iou_png_path = os.path.join(save_path, 'iou_vs_snr.png')\n",
1590
+ " iou_svg_path = os.path.join(save_path, 'iou_vs_snr.svg')\n",
1591
+ " plt.savefig(iou_png_path, format='png', bbox_inches='tight')\n",
1592
+ " plt.savefig(iou_svg_path, format='svg', bbox_inches='tight')\n",
1593
+ " \n",
1594
+ " plt.show()\n",
1595
+ " plt.close()\n",
1596
+ " \n",
1597
+ " # Plot Recall at Different IoU Thresholds vs. SNR\n",
1598
+ " plt.figure(figsize=(10, 6))\n",
1599
+ " plt.plot(snr_list, recall_05, marker='o', label='Recall @ IoU 0.5')\n",
1600
+ " plt.plot(snr_list, recall_07, marker='s', label='Recall @ IoU 0.7')\n",
1601
+ " plt.plot(snr_list, recall_09, marker='^', label='Recall @ IoU 0.9')\n",
1602
+ " plt.title('Recall at Different IoU Thresholds vs. SNR')\n",
1603
+ " plt.xlabel('SNR (dB)')\n",
1604
+ " plt.ylabel('Recall')\n",
1605
+ " plt.grid(True)\n",
1606
+ " plt.legend()\n",
1607
+ " \n",
1608
+ " # Set x-axis limits\n",
1609
+ " plt.xlim(-9, 12)\n",
1610
+ " \n",
1611
+ " # Save the plot\n",
1612
+ " recall_png_path = os.path.join(save_path, 'recall_vs_snr.png')\n",
1613
+ " recall_svg_path = os.path.join(save_path, 'recall_vs_snr.svg')\n",
1614
+ " plt.savefig(recall_png_path, format='png', bbox_inches='tight')\n",
1615
+ " plt.savefig(recall_svg_path, format='svg', bbox_inches='tight')\n",
1616
+ " \n",
1617
+ " plt.show()\n",
1618
+ " plt.close()\n"
1619
+ ]
1620
+ },
1621
+ {
1622
+ "cell_type": "code",
1623
+ "execution_count": null,
1624
+ "id": "1974e70d",
1625
+ "metadata": {
1626
+ "scrolled": false
1627
+ },
1628
+ "outputs": [],
1629
+ "source": [
1630
+ "save_path = 'CMuSeNet_results/OTA'\n",
1631
+ "\n",
1632
+ "# Save results and generate plots\n",
1633
+ "save_results_and_plot(snr_metrics, save_path)"
1634
+ ]
1635
+ }
1636
+ ],
1637
+ "metadata": {
1638
+ "kernelspec": {
1639
+ "display_name": "Python 3 (ipykernel)",
1640
+ "language": "python",
1641
+ "name": "python3"
1642
+ },
1643
+ "language_info": {
1644
+ "codemirror_mode": {
1645
+ "name": "ipython",
1646
+ "version": 3
1647
+ },
1648
+ "file_extension": ".py",
1649
+ "mimetype": "text/x-python",
1650
+ "name": "python",
1651
+ "nbconvert_exporter": "python",
1652
+ "pygments_lexer": "ipython3",
1653
+ "version": "3.10.9"
1654
+ }
1655
+ },
1656
+ "nbformat": 4,
1657
+ "nbformat_minor": 5
1658
+ }
CMuSeNet_Synthetic.ipynb ADDED
@@ -0,0 +1,1241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "b5007b71",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Initialization"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "3e6b1226",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "### Initialization block\n",
19
+ "from pathlib import Path\n",
20
+ "import numpy as np\n",
21
+ "import json\n",
22
+ "import torch\n",
23
+ "import numpy as np\n",
24
+ "from tqdm import tqdm\n",
25
+ "import math\n",
26
+ "from torch.utils.data import DataLoader, TensorDataset\n",
27
+ "\n",
28
+ "STFT_LENGTH = 16 * 1024\n",
29
+ "DATA_DIR = Path(\"dataset/\")\n",
30
+ "SAMPLE_RATE = 20e6\n",
31
+ "MODULATIONS = [\"QPSK\", \"BPSK\", \"8-PSK\", \"8-QAM\", \"16-QAM\", \"GMSK\", \"2-FSK\"]\n",
32
+ "MODULATION_LABELS = {j: i for i, j in enumerate(MODULATIONS)}\n",
33
+ "NUMBER_OF_MODULATIONS = len(MODULATIONS)\n",
34
+ "\n",
35
+ "def load_data(snr, name, load_metadata_only=False):\n",
36
+ " if not load_metadata_only:\n",
37
+ " with open(DATA_DIR/str(snr)/str(name)/\"data.dat\", \"rb\") as f:\n",
38
+ " signal = np.fromfile(f, dtype=np.complex128)\n",
39
+ " else:\n",
40
+ " signal = None\n",
41
+ " with open(DATA_DIR/str(snr)/str(name)/\"meta-data.json\") as f:\n",
42
+ " meta = json.load(f)\n",
43
+ " if type(meta) == dict:\n",
44
+ " meta = [meta]\n",
45
+ " return signal, meta\n",
46
+ "\n",
47
+ " \n",
48
+ "def _get_all_numbered_dirs(root_dir):\n",
49
+ " dirs = []\n",
50
+ " for directory in root_dir.iterdir():\n",
51
+ " dirs.append(int(directory.name))\n",
52
+ " dirs.sort()\n",
53
+ " return dirs\n",
54
+ "\n",
55
+ "def get_signals(snr):\n",
56
+ " return _get_all_numbered_dirs(Path(DATA_DIR)/str(snr))\n",
57
+ "\n",
58
+ "\n",
59
+ "def get_snrs(root_dir=DATA_DIR):\n",
60
+ " return _get_all_numbered_dirs(root_dir)\n",
61
+ " \n",
62
+ " \n",
63
+ "def process_metadata(metadata):\n",
64
+ " scaled_metadata = [\n",
65
+ " {\n",
66
+ " \"position\": (SAMPLE_RATE/2 + i['fc'], i['bw']),\n",
67
+ " \"mod\": i[\"mod\"]\n",
68
+ " }\n",
69
+ " for i in metadata\n",
70
+ " ]\n",
71
+ " return scaled_metadata\n",
72
+ "\n",
73
+ "\n",
74
+ "def process_signal(signal):\n",
75
+ " signal = signal[:STFT_LENGTH]\n",
76
+ "\n",
77
+ " signal = np.fft.fft(signal)\n",
78
+ " signal = np.fft.fftshift(signal)\n",
79
+ " signal /= np.max(np.abs(signal))\n",
80
+ " \n",
81
+ " #return np.expand_dims(signal, axis=0)\n",
82
+ " return signal"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "id": "440b802c",
88
+ "metadata": {},
89
+ "source": [
90
+ "### Data Loading"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "31bc3770",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "MASK_SIZE = int(STFT_LENGTH)\n",
101
+ "\n",
102
+ "class WidebandSignalDataset(torch.utils.data.Dataset):\n",
103
+ " def __init__(self, signal_ids, mask_size=MASK_SIZE, return_snr=False):\n",
104
+ " self.mask_size = mask_size\n",
105
+ " self.signal_ids = signal_ids\n",
106
+ " self.return_snr = return_snr # New parameter to control SNR return\n",
107
+ " loaded_data = []\n",
108
+ " for snr, signal_id in tqdm(self.signal_ids):\n",
109
+ " signal, masks = self.process_signal(snr, signal_id)\n",
110
+ " loaded_data.append((signal, masks))\n",
111
+ " self.loaded_data = loaded_data\n",
112
+ "\n",
113
+ " def __len__(self):\n",
114
+ " return len(self.signal_ids)\n",
115
+ "\n",
116
+ " def __getitem__(self, index):\n",
117
+ " signal, masks = self.loaded_data[index]\n",
118
+ " if self.return_snr:\n",
119
+ " snr, _ = self.signal_ids[index]\n",
120
+ " return signal, masks, snr # Return SNR during evaluation\n",
121
+ " else:\n",
122
+ " return signal, masks # Return only signal and masks during training\n",
123
+ "\n",
124
+ " def process_signal(self, snr, signal_id):\n",
125
+ " signal, metadata = load_data(snr, signal_id)\n",
126
+ " scaled_metadata = process_metadata(metadata)\n",
127
+ " signal = process_signal(signal)\n",
128
+ " signal = torch.from_numpy(signal)\n",
129
+ " masks = torch.zeros(self.mask_size)\n",
130
+ " scale_ratio = self.mask_size / SAMPLE_RATE\n",
131
+ " for meta in scaled_metadata:\n",
132
+ " f, b = meta['position']\n",
133
+ " x1, x2 = math.floor((f - b / 2) * scale_ratio), math.ceil((f + b / 2) * scale_ratio)\n",
134
+ " masks[x1:x2] = 1\n",
135
+ " return signal.type(torch.complex64), masks.type(torch.FloatTensor)\n",
136
+ "\n",
137
+ "# Train test split 80 - 10 - 10\n",
138
+ "train, test, validation = [], [], [] \n",
139
+ "for snr in get_snrs():\n",
140
+ " signals = get_signals(snr)\n",
141
+ " total_signals = len(signals)\n",
142
+ " for signal in signals:\n",
143
+ " if signal <= 0.8 * total_signals:\n",
144
+ " train.append((snr, signal))\n",
145
+ " elif signal <= 0.9 * total_signals:\n",
146
+ " validation.append((snr, signal))\n",
147
+ " else:\n",
148
+ " test.append((snr, signal))\n",
149
+ " \n",
150
+ "print(\"Train\", len(train))\n",
151
+ "print(\"Validation\", len(validation))\n",
152
+ "print(\"Test\", len(test))\n",
153
+ "\n",
154
+ "train_dataset = WidebandSignalDataset(signal_ids=train)\n",
155
+ "validation_dataset = WidebandSignalDataset(signal_ids=validation)\n",
156
+ "test_dataset = WidebandSignalDataset(signal_ids=test)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "id": "637ae774",
162
+ "metadata": {},
163
+ "source": [
164
+ "### Batch Loading"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "a9af2450",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "batch_size = 64 # Updated batch size\n",
175
+ "\n",
176
+ "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
177
+ "valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)\n",
178
+ "\n",
179
+ "print(\"Train labels shape:\", len(train_dataset))\n",
180
+ "print(\"Validation labels shape:\", len(validation_dataset))"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "markdown",
185
+ "id": "9a8e09e4",
186
+ "metadata": {},
187
+ "source": [
188
+ "### Early Stop"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "24f79a24",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "import os\n",
199
+ "\n",
200
+ "class EarlyStopping:\n",
201
+ " def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./models/CMuSeNet'):\n",
202
+ " self.patience = patience\n",
203
+ " self.verbose = verbose\n",
204
+ " self.delta = delta\n",
205
+ " self.counter = 0\n",
206
+ " self.best_score = None\n",
207
+ " self.early_stop = False\n",
208
+ " self.val_loss_min = float('inf')\n",
209
+ " self.best_model = None\n",
210
+ " self.save_path = save_path\n",
211
+ " os.makedirs(save_path, exist_ok=True)\n",
212
+ " \n",
213
+ " def __call__(self, val_loss, model):\n",
214
+ " score = -val_loss\n",
215
+ "\n",
216
+ " if self.best_score is None:\n",
217
+ " self.best_score = score\n",
218
+ " self.save_checkpoint(val_loss, model)\n",
219
+ " elif score < self.best_score + self.delta:\n",
220
+ " self.counter += 1\n",
221
+ " if self.verbose:\n",
222
+ " print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
223
+ " if self.counter >= self.patience:\n",
224
+ " self.early_stop = True\n",
225
+ " else:\n",
226
+ " self.best_score = score\n",
227
+ " self.save_checkpoint(val_loss, model)\n",
228
+ " self.counter = 0\n",
229
+ "\n",
230
+ " def save_checkpoint(self, val_loss, model):\n",
231
+ " if self.verbose:\n",
232
+ " print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n",
233
+ " self.val_loss_min = val_loss\n",
234
+ " self.best_model = model.state_dict()\n",
235
+ " save_path = os.path.join(self.save_path, 'best_model.pth')\n",
236
+ " torch.save(self.best_model, save_path)"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "id": "6c3fda74",
242
+ "metadata": {},
243
+ "source": [
244
+ "### Reshape"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "id": "5fcf91db",
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "import torch.nn as nn\n",
255
+ "import complexPyTorch.complexLayers as cplx\n",
256
+ "import torch.nn.functional as F\n",
257
+ "import torch\n",
258
+ "\n",
259
+ "def reshape_to_2d(data):\n",
260
+ " return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "id": "b7d7562c",
266
+ "metadata": {},
267
+ "source": [
268
+ "### Complex IoU"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": null,
274
+ "id": "7218c3f3",
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "def calculate_iou(pred, target, threshold=0.5):\n",
279
+ " real_pred = (pred.real > threshold).float()\n",
280
+ " imag_pred = (pred.imag > threshold).float()\n",
281
+ " \n",
282
+ " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
283
+ " \n",
284
+ " intersection = (combined_pred * target).sum(dim=1)\n",
285
+ " union = (combined_pred + target).sum(dim=1) - intersection\n",
286
+ " iou = (intersection / union).mean().item()\n",
287
+ " return iou"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "markdown",
292
+ "id": "64f4063c",
293
+ "metadata": {},
294
+ "source": [
295
+ "### Training"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "id": "66825110",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "import time\n",
306
+ "\n",
307
+ "def validate_model(model, valid_loader, criterion):\n",
308
+ " model.eval()\n",
309
+ " running_loss = 0.0\n",
310
+ " iou_scores = []\n",
311
+ " total_correct = 0\n",
312
+ " total_samples = 0\n",
313
+ "\n",
314
+ " with torch.no_grad():\n",
315
+ " for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n",
316
+ " inputs = reshape_to_2d(inputs).to(device)\n",
317
+ " masks = masks.to(device)\n",
318
+ " outputs = model(inputs)\n",
319
+ " loss = criterion(outputs, masks)\n",
320
+ " running_loss += loss.item()\n",
321
+ "\n",
322
+ " # Calculate IoU\n",
323
+ " iou = calculate_iou(outputs, masks, threshold=0.5)\n",
324
+ " iou_scores.append(iou)\n",
325
+ " \n",
326
+ " # Calculate accuracy\n",
327
+ " preds = ((outputs.real > 0.5) & (outputs.imag > 0.5)).float()\n",
328
+ " correct = (preds == masks).float().sum()\n",
329
+ " total_correct += correct.item()\n",
330
+ " total_samples += masks.numel()\n",
331
+ "\n",
332
+ " val_loss = running_loss / len(valid_loader)\n",
333
+ " mean_iou = sum(iou_scores) / len(iou_scores)\n",
334
+ " accuracy = total_correct / total_samples * 100\n",
335
+ "\n",
336
+ " print(f'Validation Loss: {val_loss:.6f}')\n",
337
+ " print(f'Validation Accuracy: {accuracy:.2f}%')\n",
338
+ "\n",
339
+ " return val_loss, accuracy\n",
340
+ "\n",
341
+ "def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.00001], num_epochs=50, patience=3):\n",
342
+ " train_losses = []\n",
343
+ " val_losses = []\n",
344
+ " val_accuracies = []\n",
345
+ " epoch_durations = []\n",
346
+ " \n",
347
+ " current_lr = initial_lr\n",
348
+ " for lr in lr_steps:\n",
349
+ " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
350
+ " early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n",
351
+ " print(\"Current learning rate: \", lr)\n",
352
+ " for epoch in range(num_epochs):\n",
353
+ " epoch_start_time = time.time()\n",
354
+ " \n",
355
+ " model.train()\n",
356
+ " running_loss = 0.0\n",
357
+ " for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n",
358
+ " inputs = reshape_to_2d(inputs).to(device)\n",
359
+ " masks = masks.to(device)\n",
360
+ " outputs = model(inputs)\n",
361
+ " loss = criterion(outputs, masks)\n",
362
+ "\n",
363
+ " optimizer.zero_grad()\n",
364
+ " loss.backward()\n",
365
+ " optimizer.step()\n",
366
+ "\n",
367
+ " running_loss += loss.item()\n",
368
+ "\n",
369
+ " epoch_loss = running_loss / len(train_loader)\n",
370
+ " train_losses.append(epoch_loss)\n",
371
+ " print(f\"Training Loss: {epoch_loss:.6f}\")\n",
372
+ "\n",
373
+ " val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n",
374
+ " val_losses.append(val_loss)\n",
375
+ " val_accuracies.append(val_accuracy)\n",
376
+ " early_stopping(val_loss, model)\n",
377
+ "\n",
378
+ " if early_stopping.early_stop:\n",
379
+ " print(\"Early stopping triggered\")\n",
380
+ " break\n",
381
+ "\n",
382
+ " epoch_duration = time.time() - epoch_start_time\n",
383
+ " epoch_durations.append(epoch_duration)\n",
384
+ " if early_stopping.best_model is not None:\n",
385
+ " print(f\"Loading best model from lr {lr}\")\n",
386
+ " model.load_state_dict(early_stopping.best_model)\n",
387
+ " \n",
388
+ " print(\"Training completed.\")\n",
389
+ " print(\"Epoch durations:\", epoch_durations)\n",
390
+ " return model, train_losses, val_losses, val_accuracies, epoch_durations"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "id": "0b80cb51",
396
+ "metadata": {},
397
+ "source": [
398
+ "### ResNet-18"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "id": "2d208cb9",
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "import torch\n",
409
+ "import torch.nn as nn\n",
410
+ "import complexPyTorch.complexLayers as cplx\n",
411
+ "from typing import Optional, Callable, Type, Union, List\n",
412
+ "import torch.nn.functional as F\n",
413
+ "from torch import Tensor\n",
414
+ "\n",
415
+ "def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
416
+ " \"\"\"3x3 convolution with padding\"\"\"\n",
417
+ " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
418
+ "\n",
419
+ "def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n",
420
+ " \"\"\"1x1 convolution\"\"\"\n",
421
+ " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
422
+ "\n",
423
+ "class BasicBlock(nn.Module):\n",
424
+ " expansion = 1\n",
425
+ "\n",
426
+ " def __init__(\n",
427
+ " self,\n",
428
+ " inplanes: int,\n",
429
+ " planes: int,\n",
430
+ " stride: int = 1,\n",
431
+ " downsample: Optional[nn.Module] = None,\n",
432
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
433
+ " ) -> None:\n",
434
+ " super(BasicBlock, self).__init__()\n",
435
+ " self.conv1 = conv3x3(inplanes, planes, stride)\n",
436
+ " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
437
+ " self.relu = cplx.ComplexReLU()\n",
438
+ " self.conv2 = conv3x3(planes, planes)\n",
439
+ " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
440
+ " self.downsample = downsample\n",
441
+ " self.stride = stride\n",
442
+ "\n",
443
+ " def forward(self, x: Tensor) -> Tensor:\n",
444
+ " identity = x\n",
445
+ "\n",
446
+ " out = self.conv1(x)\n",
447
+ " out = self.bn1(out)\n",
448
+ " out = self.relu(out)\n",
449
+ "\n",
450
+ " out = self.conv2(out)\n",
451
+ " out = self.bn2(out)\n",
452
+ "\n",
453
+ " if self.downsample is not None:\n",
454
+ " identity = self.downsample(x)\n",
455
+ "\n",
456
+ " out += identity\n",
457
+ " out = self.relu(out)\n",
458
+ "\n",
459
+ " return out\n",
460
+ "\n",
461
+ "class Bottleneck(nn.Module):\n",
462
+ " expansion = 4\n",
463
+ "\n",
464
+ " def __init__(\n",
465
+ " self,\n",
466
+ " inplanes: int,\n",
467
+ " planes: int,\n",
468
+ " stride: int = 1,\n",
469
+ " downsample: Optional[nn.Module] = None,\n",
470
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
471
+ " ) -> None:\n",
472
+ " super(Bottleneck, self).__init__()\n",
473
+ " self.conv1 = conv1x1(inplanes, planes)\n",
474
+ " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n",
475
+ " self.conv2 = conv3x3(planes, planes, stride)\n",
476
+ " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n",
477
+ " self.conv3 = conv1x1(planes, planes * self.expansion)\n",
478
+ " self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n",
479
+ " self.relu = cplx.ComplexReLU()\n",
480
+ " self.downsample = downsample\n",
481
+ " self.stride = stride\n",
482
+ "\n",
483
+ " def forward(self, x: Tensor) -> Tensor:\n",
484
+ " identity = x\n",
485
+ "\n",
486
+ " out = self.conv1(x)\n",
487
+ " out = self.bn1(out)\n",
488
+ " out = self.relu(out)\n",
489
+ "\n",
490
+ " out = self.conv2(out)\n",
491
+ " out = self.bn2(out)\n",
492
+ " out = self.relu(out)\n",
493
+ "\n",
494
+ " out = self.conv3(out)\n",
495
+ " out = self.bn3(out)\n",
496
+ "\n",
497
+ " if self.downsample is not None:\n",
498
+ " identity = self.downsample(x)\n",
499
+ "\n",
500
+ " out += identity\n",
501
+ " out = self.relu(out)\n",
502
+ "\n",
503
+ " return out\n",
504
+ "\n",
505
+ "class ComplexResNet(nn.Module):\n",
506
+ " def __init__(\n",
507
+ " self,\n",
508
+ " block: Type[Union[BasicBlock, Bottleneck]],\n",
509
+ " layers: List[int],\n",
510
+ " num_classes: int = STFT_LENGTH,\n",
511
+ " zero_init_residual: bool = False,\n",
512
+ " groups: int = 1,\n",
513
+ " width_per_group: int = 64,\n",
514
+ " norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
515
+ " ) -> None:\n",
516
+ " super(ComplexResNet, self).__init__()\n",
517
+ " if norm_layer is None:\n",
518
+ " norm_layer = cplx.ComplexBatchNorm2d\n",
519
+ " self._norm_layer = norm_layer\n",
520
+ "\n",
521
+ " self.inplanes = 64\n",
522
+ " self.dilation = 1\n",
523
+ "\n",
524
+ " self.groups = groups\n",
525
+ " self.base_width = width_per_group\n",
526
+ " self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n",
527
+ " self.bn1 = norm_layer(self.inplanes)\n",
528
+ " self.relu = cplx.ComplexReLU()\n",
529
+ " self.layer1 = self._make_layer(block, 64, layers[0])\n",
530
+ " self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
531
+ " self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
532
+ " self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
533
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
534
+ " self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n",
535
+ " self.sigmoid = cplx.ComplexSigmoid()\n",
536
+ "\n",
537
+ " def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n",
538
+ " norm_layer = self._norm_layer\n",
539
+ " downsample = None\n",
540
+ " if stride != 1 or self.inplanes != planes * block.expansion:\n",
541
+ " downsample = nn.Sequential(\n",
542
+ " conv1x1(self.inplanes, planes * block.expansion, stride),\n",
543
+ " norm_layer(planes * block.expansion),\n",
544
+ " )\n",
545
+ "\n",
546
+ " layers = []\n",
547
+ " layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n",
548
+ " self.inplanes = planes * block.expansion\n",
549
+ " for _ in range(1, blocks):\n",
550
+ " layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n",
551
+ "\n",
552
+ " return nn.Sequential(*layers)\n",
553
+ "\n",
554
+ " def _forward_impl(self, x: Tensor) -> Tensor:\n",
555
+ " x = self.conv1(x)\n",
556
+ " x = self.bn1(x)\n",
557
+ " x = self.relu(x)\n",
558
+ "\n",
559
+ " x = self.layer1(x)\n",
560
+ " x = self.layer2(x)\n",
561
+ " x = self.layer3(x)\n",
562
+ " x = self.layer4(x)\n",
563
+ "\n",
564
+ " x = self.avgpool(x)\n",
565
+ " x = torch.flatten(x, 1)\n",
566
+ " x = self.fc(x)\n",
567
+ " x = self.sigmoid(x)\n",
568
+ " return x\n",
569
+ "\n",
570
+ " def forward(self, x: Tensor) -> Tensor:\n",
571
+ " return self._forward_impl(x)\n",
572
+ "\n",
573
+ "def ComplexResNet18():\n",
574
+ " return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n",
575
+ "\n",
576
+ "# Create the model instance\n",
577
+ "model = ComplexResNet18()\n",
578
+ "print(model)\n"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "markdown",
583
+ "id": "e4bc1b5d",
584
+ "metadata": {},
585
+ "source": [
586
+ "### Complex focal Loss"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": null,
592
+ "id": "61c29429",
593
+ "metadata": {},
594
+ "outputs": [],
595
+ "source": [
596
+ "class ComplexFocalLoss(nn.Module):\n",
597
+ " def __init__(self, alpha=1, gamma=2, reduction='mean'):\n",
598
+ " super(ComplexFocalLoss, self).__init__()\n",
599
+ " self.alpha = alpha\n",
600
+ " self.gamma = gamma\n",
601
+ " self.reduction = reduction\n",
602
+ "\n",
603
+ " def forward(self, inputs, targets):\n",
604
+ " real_inputs = inputs.real\n",
605
+ " imag_inputs = inputs.imag\n",
606
+ " \n",
607
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n",
608
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n",
609
+ " \n",
610
+ " real_pt = torch.exp(-real_BCE_loss)\n",
611
+ " imag_pt = torch.exp(-imag_BCE_loss)\n",
612
+ " \n",
613
+ " real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n",
614
+ " imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n",
615
+ "\n",
616
+ " if self.reduction == 'mean':\n",
617
+ " return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n",
618
+ " elif self.reduction == 'sum':\n",
619
+ " return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n",
620
+ " else:\n",
621
+ " return real_F_loss + imag_F_loss\n",
622
+ "\n",
623
+ "# Update the IoU calculation to handle complex values\n",
624
+ "def calculate_iou(pred, target, threshold=0.5):\n",
625
+ " real_pred = (pred.real > threshold).float()\n",
626
+ " imag_pred = (pred.imag > threshold).float()\n",
627
+ " \n",
628
+ " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n",
629
+ " \n",
630
+ " intersection = (combined_pred * target).sum(dim=1)\n",
631
+ " union = (combined_pred + target).sum(dim=1) - intersection\n",
632
+ " iou = (intersection / union).mean().item()\n",
633
+ " return iou"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "markdown",
638
+ "id": "abb35ba2",
639
+ "metadata": {},
640
+ "source": [
641
+ "### Training with complex focal loss"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "id": "86d7526b",
648
+ "metadata": {},
649
+ "outputs": [],
650
+ "source": [
651
+ "# Initialize and train the CResNet-18 model\n",
652
+ "model = ComplexResNet18().to(device)\n",
653
+ "criterion = ComplexFocalLoss()\n",
654
+ "\n",
655
+ "# Train the model and validate it\n",
656
+ "#0.001, 0.0001, 0.00001, 0.000001\n",
657
+ "model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n",
658
+ "combined_epoch_time = sum(epoch_durations)\n",
659
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "markdown",
664
+ "id": "fd0c9d58",
665
+ "metadata": {},
666
+ "source": [
667
+ "### CVNN RV-BCE and CV-BCE Loss function implementation"
668
+ ]
669
+ },
670
+ {
671
+ "cell_type": "code",
672
+ "execution_count": null,
673
+ "id": "99c736b8",
674
+ "metadata": {},
675
+ "outputs": [],
676
+ "source": [
677
+ "# RV BCE Loss Function Definition\n",
678
+ "class RealValuedBCELoss(nn.Module):\n",
679
+ " def __init__(self, reduction='mean'):\n",
680
+ " super(RealValuedBCELoss, self).__init__()\n",
681
+ " self.reduction = reduction\n",
682
+ "\n",
683
+ " def forward(self, inputs, targets):\n",
684
+ " # Use only the real part of the complex inputs\n",
685
+ " real_inputs = inputs.real\n",
686
+ " BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
687
+ " return BCE_loss\n",
688
+ "\n",
689
+ " \n",
690
+ "# CV BCE Loss Function Definition\n",
691
+ "class ComplexValuedBCELoss(nn.Module):\n",
692
+ " def __init__(self, reduction='mean'):\n",
693
+ " super(ComplexValuedBCELoss, self).__init__()\n",
694
+ " self.reduction = reduction\n",
695
+ "\n",
696
+ " def forward(self, inputs, targets):\n",
697
+ " real_inputs = inputs.real\n",
698
+ " imag_inputs = inputs.imag\n",
699
+ "\n",
700
+ " # Calculate binary cross-entropy for both real and imaginary parts\n",
701
+ " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n",
702
+ " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n",
703
+ " \n",
704
+ " # Combine the losses (you can adjust the weighting if necessary)\n",
705
+ " combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n",
706
+ " return combined_BCE_loss"
707
+ ]
708
+ },
709
+ {
710
+ "cell_type": "markdown",
711
+ "id": "d6930f39",
712
+ "metadata": {},
713
+ "source": [
714
+ "### RV-BCE Training"
715
+ ]
716
+ },
717
+ {
718
+ "cell_type": "code",
719
+ "execution_count": null,
720
+ "id": "9e59d4c9",
721
+ "metadata": {},
722
+ "outputs": [],
723
+ "source": [
724
+ "# Set the criterion for RV BCE\n",
725
+ "criterion = RealValuedBCELoss()\n",
726
+ "\n",
727
+ "# Train the ResNet-18 model with RV BCE\n",
728
+ "device = torch.device('cuda')\n",
729
+ "model = ComplexResNet18().to(device)\n",
730
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
731
+ "\n",
732
+ "# Start training with the previously defined train_model function\n",
733
+ "model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n",
734
+ " model, train_loader, valid_loader, criterion, \n",
735
+ " initial_lr=0.001, lr_steps=[0.001, 0.0001, 0.00001, 0.000001], num_epochs=50, patience=3\n",
736
+ ")\n"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "markdown",
741
+ "id": "93d19ea7",
742
+ "metadata": {},
743
+ "source": [
744
+ "### CV-BCE Training"
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "execution_count": null,
750
+ "id": "2c56d5b4",
751
+ "metadata": {},
752
+ "outputs": [],
753
+ "source": [
754
+ "# Set the criterion for CV BCE\n",
755
+ "criterion = ComplexValuedBCELoss()\n",
756
+ "\n",
757
+ "# Train the ResNet-18 model with CV BCE\n",
758
+ "device = torch.device('cuda')\n",
759
+ "model = ComplexResNet18().to(device)\n",
760
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
761
+ "\n",
762
+ "# Start training with the previously defined train_model function\n",
763
+ "model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n",
764
+ " model, train_loader, valid_loader, criterion, \n",
765
+ " initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n",
766
+ ")\n",
767
+ "combined_epoch_time = sum(epoch_durations)\n",
768
+ "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "markdown",
773
+ "id": "f4f6530e",
774
+ "metadata": {},
775
+ "source": [
776
+ "### Plot training result (Accuracy, loss vs epoch)"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "id": "43676a01",
783
+ "metadata": {},
784
+ "outputs": [],
785
+ "source": [
786
+ "import matplotlib.pyplot as plt\n",
787
+ "import json\n",
788
+ "import os\n",
789
+ "\n",
790
+ "# Ensure the directory exists\n",
791
+ "output_dir = 'cvnn_results/segmentation'\n",
792
+ "os.makedirs(output_dir, exist_ok=True)\n",
793
+ "\n",
794
+ "def save_metrics_to_json(train_losses, val_accuracies, epoch_durations, filename):\n",
795
+ " \"\"\"\n",
796
+ " Save the training losses and validation accuracies to a JSON file.\n",
797
+ " \n",
798
+ " Args:\n",
799
+ " train_losses (list): List of training losses.\n",
800
+ " val_accuracies (list): List of validation accuracies.\n",
801
+ " filename (str): The file name for the JSON file.\n",
802
+ " \"\"\"\n",
803
+ " metrics = {\n",
804
+ " \"train_losses\": train_losses,\n",
805
+ " \"val_accuracies\": val_accuracies,\n",
806
+ " \"epoch_durations\": epoch_durations\n",
807
+ " }\n",
808
+ " with open(os.path.join(output_dir, filename), 'w') as f:\n",
809
+ " json.dump(metrics, f)\n",
810
+ "\n",
811
+ "def plot_training_metrics(train_losses, val_accuracies, plot_filename):\n",
812
+ " \"\"\"\n",
813
+ " Plot the training loss and validation accuracy, and mark the epoch where accuracy reaches 99%.\n",
814
+ " \n",
815
+ " Args:\n",
816
+ " train_losses (list): List of training losses.\n",
817
+ " val_accuracies (list): List of validation accuracies.\n",
818
+ " plot_filename (str): The file name for saving the plot as SVG.\n",
819
+ " \"\"\"\n",
820
+ " epochs = range(1, len(train_losses) + 1)\n",
821
+ "\n",
822
+ " plt.figure(figsize=(14, 6))\n",
823
+ "\n",
824
+ " # Plot Training Loss\n",
825
+ " plt.subplot(1, 2, 1)\n",
826
+ " plt.plot(epochs, train_losses, label='Training Loss')\n",
827
+ " plt.xlabel('Epochs')\n",
828
+ " plt.ylabel('Loss')\n",
829
+ " plt.title('Training Loss')\n",
830
+ " plt.legend()\n",
831
+ "\n",
832
+ " # Plot Validation Accuracy\n",
833
+ " plt.subplot(1, 2, 2)\n",
834
+ " plt.plot(epochs, val_accuracies, label='Validation Accuracy')\n",
835
+ " plt.xlabel('Epochs')\n",
836
+ " plt.ylabel('Accuracy (%)')\n",
837
+ " plt.title('Validation Accuracy')\n",
838
+ " plt.legend()\n",
839
+ "\n",
840
+ " # Find the first epoch where validation accuracy reaches or exceeds 99%\n",
841
+ " for i, acc in enumerate(val_accuracies):\n",
842
+ " if acc >= 99:\n",
843
+ " first_99_epoch = i + 1 # Epochs are 1-based\n",
844
+ " plt.axvline(first_99_epoch, color='r', linestyle='--', label=f'99% reached at epoch {first_99_epoch}')\n",
845
+ " break\n",
846
+ "\n",
847
+ " plt.legend()\n",
848
+ " plt.tight_layout()\n",
849
+ "\n",
850
+ " # Save the plot as an SVG file\n",
851
+ " plt.savefig(os.path.join(output_dir, plot_filename), format='svg')\n",
852
+ " plt.show()\n",
853
+ "\n",
854
+ "# Save the metrics to JSON in cvnn_results/segmentation\n",
855
+ "save_metrics_to_json(train_losses, val_accuracies, epoch_durations, 'training_metrics.json')\n",
856
+ "\n",
857
+ "# Plot the metrics and highlight when accuracy reaches 99%, saving the plot as SVG\n",
858
+ "plot_training_metrics(train_losses, val_accuracies, 'training_metrics_plot.svg')"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "markdown",
863
+ "id": "c6f4ea75",
864
+ "metadata": {},
865
+ "source": [
866
+ "### Evaluation "
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "execution_count": null,
872
+ "id": "a303080e",
873
+ "metadata": {},
874
+ "outputs": [],
875
+ "source": [
876
+ "# Load the pre-trained model for evaluation\n",
877
+ "import torch\n",
878
+ "\n",
879
+ "device = \"cuda\"\n",
880
+ "\n",
881
+ "model_path = \"path/to/the/model\" #Please change this to the model path you trained\n",
882
+ "model = ComplexResNet18().to(device)\n",
883
+ "model.load_state_dict(torch.load(model_path, map_location=device))\n",
884
+ "model.eval()\n"
885
+ ]
886
+ },
887
+ {
888
+ "cell_type": "code",
889
+ "execution_count": null,
890
+ "id": "0590b6ef",
891
+ "metadata": {},
892
+ "outputs": [],
893
+ "source": [
894
+ "import torch\n",
895
+ "from tqdm import tqdm\n",
896
+ "from torch.utils.data import DataLoader\n",
897
+ "import numpy as np\n",
898
+ "\n",
899
+ "# Define thresholds for recall calculation\n",
900
+ "iou_thresholds = [0.5, 0.7, 0.9]\n",
901
+ "\n",
902
+ "# Initialize metrics\n",
903
+ "snr_results = {}\n",
904
+ "total_accuracy = 0.0\n",
905
+ "total_samples = 0\n",
906
+ "iou_scores = {th: 0.0 for th in iou_thresholds}\n",
907
+ "recall_counts = {th: 0 for th in iou_thresholds}\n",
908
+ "BATCH_SIZE = 64\n",
909
+ "# Create DataLoader for the entire dataset\n",
910
+ "full_dataset = WidebandSignalDataset(signal_ids=train + validation + test, return_snr=True)\n",
911
+ "full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)"
912
+ ]
913
+ },
914
+ {
915
+ "cell_type": "markdown",
916
+ "id": "6db6a18f",
917
+ "metadata": {},
918
+ "source": [
919
+ "### Bounding Box"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "code",
924
+ "execution_count": null,
925
+ "id": "e396c72c",
926
+ "metadata": {},
927
+ "outputs": [],
928
+ "source": [
929
+ "import torch\n",
930
+ "from collections import defaultdict\n",
931
+ "import time\n",
932
+ "from tqdm import tqdm\n",
933
+ "import torch\n",
934
+ "import torch.nn.functional as F\n",
935
+ "from scipy.optimize import linear_sum_assignment\n",
936
+ "\n",
937
+ "def expand_true(array, distance=1):\n",
938
+ " # Create kernel of appropriate size\n",
939
+ " kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n",
940
+ " array = array.unsqueeze(1).float() # Add channel dimension\n",
941
+ " result = F.conv1d(array, kernel, padding=distance)\n",
942
+ " result = result.squeeze(1) # Remove the extra dimension\n",
943
+ " \n",
944
+ " # Convert values greater than 0 to `True`\n",
945
+ " return result > 0\n",
946
+ "\n",
947
+ "# Define supporting functions based on your friend's code\n",
948
+ "def get_true_groups(tensor, device):\n",
949
+ " assert tensor.dim() == 2, 'This function handles 2D tensor only'\n",
950
+ " all_groups = []\n",
951
+ " for i in range(tensor.size(0)):\n",
952
+ " item = tensor[i]\n",
953
+ " item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n",
954
+ " diffs = item.float().diff()\n",
955
+ " starts = (diffs == 1).nonzero(as_tuple=True)[0]\n",
956
+ " ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n",
957
+ " groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n",
958
+ " all_groups.append(groups)\n",
959
+ " return all_groups\n",
960
+ "\n",
961
+ "def get_target_boxes(metadata, number_of_bins, sample_rate=SAMPLE_RATE):\n",
962
+ " scale_ratio = number_of_bins / sample_rate\n",
963
+ " targets = []\n",
964
+ " masks = torch.zeros(number_of_bins)\n",
965
+ " for meta in metadata:\n",
966
+ " f, b = meta['position']\n",
967
+ " x1, x2 = math.floor((f-b/2)*scale_ratio), math.ceil((f+b/2)*scale_ratio)\n",
968
+ " masks[x1:x2] = 1\n",
969
+ " targets.append((x1, x2))\n",
970
+ " return targets, masks\n",
971
+ "\n",
972
+ "def get_target_boxes_batch(batch_metadata, number_of_bins, sample_rate=SAMPLE_RATE):\n",
973
+ " all_targets, all_masks = [], []\n",
974
+ " for metadata in batch_metadata:\n",
975
+ " targets, masks = get_target_boxes(metadata, number_of_bins, sample_rate)\n",
976
+ " all_targets.append(targets)\n",
977
+ " all_masks.append(masks)\n",
978
+ " return all_targets, all_masks\n",
979
+ "\n",
980
+ "def calculate_iou(box1, box2):\n",
981
+ " intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n",
982
+ " union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n",
983
+ " return intersection / union if union != 0 else 0\n",
984
+ "\n",
985
+ "def match_targets(targets, preds):\n",
986
+ " ious = []\n",
987
+ " for target in targets:\n",
988
+ " iou_targets = []\n",
989
+ " for pred in preds:\n",
990
+ " iou_targets.append(calculate_iou(target, pred))\n",
991
+ " ious.append(iou_targets)\n",
992
+ " return linear_sum_assignment(ious, maximize=True)\n",
993
+ "\n",
994
+ "def match_targets_batch(batch_targets, batch_preds):\n",
995
+ " all_assignments = []\n",
996
+ " for targets, preds in zip(batch_targets, batch_preds):\n",
997
+ " all_assignments.append(match_targets(targets, preds))\n",
998
+ " return all_assignments\n",
999
+ "\n",
1000
+ "def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n",
1001
+ " ious = [0 for _ in target_boxes]\n",
1002
+ " matching_dict = dict(zip(*matching))\n",
1003
+ " for target_index, target_box in enumerate(target_boxes):\n",
1004
+ " if target_index in matching_dict:\n",
1005
+ " box1 = target_box\n",
1006
+ " box2 = prediction_boxes[matching_dict[target_index]]\n",
1007
+ " ious[target_index] = calculate_iou(box1, box2)\n",
1008
+ " return ious\n",
1009
+ "\n",
1010
+ "def calculate_matched_iou_mean_batch(batch_target_boxes, batch_pred_boxes, batch_matching):\n",
1011
+ " all_ious = []\n",
1012
+ " for args in zip(batch_target_boxes, batch_pred_boxes, batch_matching):\n",
1013
+ " all_ious.append(calculate_matched_ious(*args))\n",
1014
+ " return all_ious\n",
1015
+ "\n"
1016
+ ]
1017
+ },
1018
+ {
1019
+ "cell_type": "code",
1020
+ "execution_count": null,
1021
+ "id": "24d483c1",
1022
+ "metadata": {},
1023
+ "outputs": [],
1024
+ "source": [
1025
+ "from collections import defaultdict\n",
1026
+ "from tqdm import tqdm\n",
1027
+ "def model_predictor(signals):\n",
1028
+ " # Use the already loaded model and apply thresholding\n",
1029
+ " signals = reshape_to_2d(signals)\n",
1030
+ " outputs = model(signals)\n",
1031
+ " return expand_true(outputs.real > 0.5) # Use real part for thresholding\n",
1032
+ "def evaluate(predictor, data_loader, device=\"cuda\"):\n",
1033
+ " snr_metrics = defaultdict(lambda: {\n",
1034
+ " \"iou_sum\": 0.0,\n",
1035
+ " \"iou_count\": 0,\n",
1036
+ " \"recall_counts\": defaultdict(int),\n",
1037
+ " \"total_samples\": defaultdict(int),\n",
1038
+ " \"correct_pixels\": 0,\n",
1039
+ " \"total_pixels\": 0\n",
1040
+ " })\n",
1041
+ " total_iou_sum, total_iou_count = 0.0, 0\n",
1042
+ " total_correct_pixels, total_total_pixels = 0, 0\n",
1043
+ " total_recall_counts = defaultdict(int)\n",
1044
+ " total_samples = defaultdict(int)\n",
1045
+ "\n",
1046
+ " for inputs, masks, snrs_in_batch in tqdm(data_loader, desc=\"Evaluating\"):\n",
1047
+ " #inputs = inputs.to(device)\n",
1048
+ " inputs = reshape_to_2d(inputs).to(device)\n",
1049
+ " masks = masks.to(device)\n",
1050
+ " outputs = predictor(inputs)\n",
1051
+ "\n",
1052
+ " for i in range(len(snrs_in_batch)):\n",
1053
+ " snr = snrs_in_batch[i].item()\n",
1054
+ " mask = masks[i]\n",
1055
+ " output = outputs[i]\n",
1056
+ "\n",
1057
+ " # Ensure output matches mask shape\n",
1058
+ " if output.numel() != mask.numel():\n",
1059
+ " output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n",
1060
+ "\n",
1061
+ " thresholded_output = (output.real >= 0.5).float()\n",
1062
+ "\n",
1063
+ " correct_pixels = (thresholded_output == mask).sum().item()\n",
1064
+ " total_pixels = mask.numel()\n",
1065
+ " snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n",
1066
+ " snr_metrics[snr][\"total_pixels\"] += total_pixels\n",
1067
+ " total_correct_pixels += correct_pixels\n",
1068
+ " total_total_pixels += total_pixels\n",
1069
+ "\n",
1070
+ " target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n",
1071
+ " pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n",
1072
+ " if not target_boxes or not pred_boxes:\n",
1073
+ " continue\n",
1074
+ " matching = match_targets(target_boxes, pred_boxes)\n",
1075
+ " matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n",
1076
+ "\n",
1077
+ " snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n",
1078
+ " snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n",
1079
+ " total_iou_sum += sum(matched_ious)\n",
1080
+ " total_iou_count += len(matched_ious)\n",
1081
+ "\n",
1082
+ " for th in iou_thresholds:\n",
1083
+ " true_positives = sum(1 for iou in matched_ious if iou >= th)\n",
1084
+ " snr_metrics[snr][\"recall_counts\"][th] += true_positives\n",
1085
+ " snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n",
1086
+ " total_recall_counts[th] += true_positives\n",
1087
+ " total_samples[th] += len(target_boxes)\n",
1088
+ "\n",
1089
+ " # Calculate overall metrics\n",
1090
+ " overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n",
1091
+ " overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n",
1092
+ " overall_recall = {th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0 for th in iou_thresholds}\n",
1093
+ "\n",
1094
+ " # Print overall results\n",
1095
+ " print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n",
1096
+ " print(f\"Overall IoU Score: {overall_iou:.4f}\")\n",
1097
+ " for th in iou_thresholds:\n",
1098
+ " print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n",
1099
+ "\n",
1100
+ " # Print per-SNR results\n",
1101
+ " for snr, metrics in sorted(snr_metrics.items()):\n",
1102
+ " snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n",
1103
+ " snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n",
1104
+ " print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n",
1105
+ " print(f\" IoU: {snr_iou:.4f}\")\n",
1106
+ " for th in iou_thresholds:\n",
1107
+ " recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n",
1108
+ " print(f\" Recall at threshold {th}: {recall:.4f}\")\n",
1109
+ "\n",
1110
+ " return snr_metrics\n"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "execution_count": null,
1116
+ "id": "a71c18ba",
1117
+ "metadata": {
1118
+ "scrolled": false
1119
+ },
1120
+ "outputs": [],
1121
+ "source": [
1122
+ "snr_metrics = evaluate(model_predictor, full_loader, device=device)"
1123
+ ]
1124
+ },
1125
+ {
1126
+ "cell_type": "markdown",
1127
+ "id": "87417c7b",
1128
+ "metadata": {},
1129
+ "source": [
1130
+ "### Plot and Save"
1131
+ ]
1132
+ },
1133
+ {
1134
+ "cell_type": "code",
1135
+ "execution_count": null,
1136
+ "id": "1dbfb5e6",
1137
+ "metadata": {
1138
+ "scrolled": false
1139
+ },
1140
+ "outputs": [],
1141
+ "source": [
1142
+ "import json\n",
1143
+ "import matplotlib.pyplot as plt\n",
1144
+ "from pathlib import Path\n",
1145
+ "\n",
1146
+ "# Define the path for saving the JSON file and plots\n",
1147
+ "save_path = Path(\"CMuSeNet_plots/Synthetic\")\n",
1148
+ "save_path.mkdir(parents=True, exist_ok=True)\n",
1149
+ "json_file_path = save_path / \"evaluation_results.json\"\n",
1150
+ "\n",
1151
+ "# Save metrics and plot results\n",
1152
+ "def save_and_plot_results(snr_metrics, iou_thresholds):\n",
1153
+ " # Prepare data for plotting and JSON saving\n",
1154
+ " snr_values = sorted(snr_metrics.keys())\n",
1155
+ " iou_scores = [snr_metrics[snr][\"iou_sum\"] / snr_metrics[snr][\"iou_count\"] if snr_metrics[snr][\"iou_count\"] > 0 else 0 for snr in snr_values]\n",
1156
+ " accuracies = [(snr_metrics[snr][\"correct_pixels\"] / snr_metrics[snr][\"total_pixels\"]) * 100 if snr_metrics[snr][\"total_pixels\"] > 0 else 0 for snr in snr_values]\n",
1157
+ " recalls = {th: [(snr_metrics[snr][\"recall_counts\"][th] / snr_metrics[snr][\"total_samples\"][th]) if snr_metrics[snr][\"total_samples\"][th] > 0 else 0 for snr in snr_values] for th in iou_thresholds}\n",
1158
+ "\n",
1159
+ " # Save results to JSON\n",
1160
+ " results = {\n",
1161
+ " \"SNR\": snr_values,\n",
1162
+ " \"IoU_Scores\": iou_scores,\n",
1163
+ " \"Accuracy\": accuracies,\n",
1164
+ " \"Recall\": {str(th): recalls[th] for th in iou_thresholds}\n",
1165
+ " }\n",
1166
+ " with open(json_file_path, \"w\") as f:\n",
1167
+ " json.dump(results, f, indent=4)\n",
1168
+ " print(f\"Results saved to {json_file_path}\")\n",
1169
+ "\n",
1170
+ " # Plot IoU vs SNR\n",
1171
+ " plt.figure()\n",
1172
+ " plt.plot(snr_values, iou_scores, marker='o', label=\"IoU Score\")\n",
1173
+ " plt.xlabel(\"SNR (dB)\")\n",
1174
+ " plt.ylabel(\"IoU Score\")\n",
1175
+ " plt.title(\"IoU Score vs. SNR\")\n",
1176
+ " plt.grid(True)\n",
1177
+ " plt.legend()\n",
1178
+ " plt.savefig(save_path / \"IoU_vs_SNR.png\")\n",
1179
+ " plt.savefig(save_path / \"IoU_vs_SNR.svg\")\n",
1180
+ " plt.show()\n",
1181
+ "\n",
1182
+ " # Plot Accuracy vs SNR\n",
1183
+ " plt.figure()\n",
1184
+ " plt.plot(snr_values, accuracies, marker='o', label=\"Accuracy\")\n",
1185
+ " plt.xlabel(\"SNR (dB)\")\n",
1186
+ " plt.ylabel(\"Accuracy (%)\")\n",
1187
+ " plt.title(\"Accuracy vs. SNR (Threshold 0.5)\")\n",
1188
+ " plt.grid(True)\n",
1189
+ " plt.legend()\n",
1190
+ " plt.savefig(save_path / \"Accuracy_vs_SNR.png\")\n",
1191
+ " plt.savefig(save_path / \"Accuracy_vs_SNR.svg\")\n",
1192
+ " plt.show()\n",
1193
+ "\n",
1194
+ " # Plot Recall vs SNR for each threshold\n",
1195
+ " for th in iou_thresholds:\n",
1196
+ " plt.figure()\n",
1197
+ " plt.plot(snr_values, recalls[th], marker='o', label=f\"Recall at {th}\")\n",
1198
+ " plt.xlabel(\"SNR (dB)\")\n",
1199
+ " plt.ylabel(\"Recall\")\n",
1200
+ " plt.title(f\"Recall vs. SNR (Threshold {th})\")\n",
1201
+ " plt.grid(True)\n",
1202
+ " plt.legend()\n",
1203
+ " plt.savefig(save_path / f\"Recall_vs_SNR_{th}.png\")\n",
1204
+ " plt.savefig(save_path / f\"Recall_vs_SNR_{th}.svg\")\n",
1205
+ " plt.show()\n",
1206
+ "\n",
1207
+ "# Call this after running evaluate() to save and plot results\n",
1208
+ "save_and_plot_results(snr_metrics, iou_thresholds)"
1209
+ ]
1210
+ },
1211
+ {
1212
+ "cell_type": "code",
1213
+ "execution_count": null,
1214
+ "id": "d0c0d3e8",
1215
+ "metadata": {},
1216
+ "outputs": [],
1217
+ "source": []
1218
+ }
1219
+ ],
1220
+ "metadata": {
1221
+ "kernelspec": {
1222
+ "display_name": "Python 3 (ipykernel)",
1223
+ "language": "python",
1224
+ "name": "python3"
1225
+ },
1226
+ "language_info": {
1227
+ "codemirror_mode": {
1228
+ "name": "ipython",
1229
+ "version": 3
1230
+ },
1231
+ "file_extension": ".py",
1232
+ "mimetype": "text/x-python",
1233
+ "name": "python",
1234
+ "nbconvert_exporter": "python",
1235
+ "pygments_lexer": "ipython3",
1236
+ "version": "3.10.9"
1237
+ }
1238
+ },
1239
+ "nbformat": 4,
1240
+ "nbformat_minor": 5
1241
+ }
CMuSeNet_Synthetic_IQ_Generator/README.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This matlab function is generated and tested in MATLAB 2022, and 2024
2
+
3
+ Please open datagen.m script and run it with MATLAB to generate synthetic dataset with same configuration as CMuSeNet synthetic dataset.
4
+
5
+ In this script you can change various setting such as channel (AWGN, Rician, Rayleigh), sample speed, range of SNR and sample bandwidth.
6
+
7
+ This dataset is used to train CMuSeNet, complex-valued multi-signla segmentation Network.
8
+
9
+ Please cite our paper if you use this dataset or synthetic dataset generation script.
10
+
11
+ @inproceedings{shin2025cmusenet,
12
+ title={I Can't Believe It's Not Real: {CV-MuSeNet}: Complex-Valued Multi-Signal Segmentation},
13
+ author={Sangwon Shin and Mehmet C. Vuran},
14
+ booktitle={IEEE Dynamic Spectrum Access Networks (DySPAN)},
15
+ year={2025},
16
+ organization={IEEE}
17
+ }
18
+
19
+ Acknowledgement: Office of Naval Research, NSWC Crane N00174-23-1-0007
20
+ This work relates to Department of Navy award N00174-23-1-0007 issued by the Office of Naval Research, NSWC Crane. Any opinions,
21
+ findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views of the Office of Naval Research.
22
+
23
+ Following IQ samples generation script is coded by Prashant Subedi, Sangwon Shin and Dr. Mehmet Can Vuran - Cyber Physical Networking (CPN) Lab at University of Nebraska - Lincoln
24
+
25
+ License:
26
+ This IQ samples is licensed under the GPL family (General Public License) terms.
CMuSeNet_Synthetic_IQ_Generator/datagen.m ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ path = "../diff-snr-matlab-simulated-data";
2
+
3
+ for snr = -20:2:10
4
+ disp(snr);
5
+ mkdir(sprintf("%s/%d/", path, snr));
6
+ for i = 1:5000
7
+ name = string(i);
8
+ channelType = 'awgn'; %Supported channel type: awgn, rician (Flat), rayleigh (Flat)
9
+ [meta, data] = datagenWideband(snr, channelType);
10
+ split = reshape([real(data) imag(data)].', 1, []);
11
+
12
+ % Save data file
13
+ mkdir(sprintf("%s/%d/%s", path, snr, name));
14
+
15
+
16
+ datafile = fopen(sprintf("%s/%d/%s/data.dat", path, snr, name), 'w');
17
+ fwrite(datafile, split, 'double');
18
+ fclose(datafile);
19
+
20
+ % Save meta file
21
+ metafile = fopen(sprintf("%s/%d/%s/meta-data.json", path, snr, name), 'w');
22
+ fprintf(metafile, jsonencode(meta));
23
+ fclose(metafile);
24
+
25
+
26
+ disp(name);
27
+ end
28
+ end
CMuSeNet_Synthetic_IQ_Generator/datagenTransmitter.m ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function transmittedSignal = datagenTransmitter( ...
2
+ modulation, ...
3
+ rolloffFactor, ...
4
+ filterSpanInSymbols, ...
5
+ samplesPerSymbol, ...
6
+ symbolRate, ...
7
+ messageDuration ...
8
+ )
9
+ requiresFilter = true;
10
+ if modulation == "QPSK"
11
+ bitsPerSymbol = 2;
12
+ modulator = comm.QPSKModulator( ...
13
+ 'BitInput', true, ...
14
+ 'PhaseOffset', pi/4, ...
15
+ 'OutputDataType', 'double' ...
16
+ );
17
+ elseif modulation == "BPSK"
18
+ bitsPerSymbol = 1;
19
+ modulator = comm.BPSKModulator;
20
+ elseif modulation == "8-PSK"
21
+ bitsPerSymbol = 3;
22
+ modulator = @(x) qammod(bit2int(x, 3), 8);
23
+ elseif modulation == "8-QAM"
24
+ bitsPerSymbol = 3;
25
+ modulator = @(x) pskmod(bit2int(x, 3), 8);
26
+ elseif modulation == "16-QAM"
27
+ bitsPerSymbol = 4;
28
+ modulator = @(x) qammod(bit2int(x, 4), 16);
29
+ elseif modulation == "GMSK"
30
+ bitsPerSymbol = 1;
31
+ modulator = comm.GMSKModulator("SamplesPerSymbol", samplesPerSymbol, ...
32
+ "BitInput", true);
33
+ requiresFilter = false;
34
+ elseif modulation == "2-FSK"
35
+ bitsPerSymbol = 1;
36
+ fdev = floor(symbolRate/4);
37
+ samplesPerSymbol = 8;
38
+ modulator = @(x) fskmod(x, 2, fdev, samplesPerSymbol, symbolRate);
39
+ requiresFilter = false;
40
+ else
41
+ error("Not implemented " + modulation);
42
+ end
43
+
44
+ transmittedBin = randi( ...
45
+ [0 1], ...
46
+ bitsPerSymbol * symbolRate * messageDuration/samplesPerSymbol, ...
47
+ 1 ...
48
+ );
49
+
50
+ modulatedData = modulator(transmittedBin); % Modulates the bits into QPSK symbols
51
+
52
+ if requiresFilter
53
+ transmitterFilter = comm.RaisedCosineTransmitFilter( ...
54
+ 'RolloffFactor', rolloffFactor, ...
55
+ 'FilterSpanInSymbols', filterSpanInSymbols, ...
56
+ 'OutputSamplesPerSymbol', samplesPerSymbol ...
57
+ );
58
+ transmittedSignal = transmitterFilter(modulatedData); % Square root Raised Cosine Transmit Filter
59
+ else
60
+ transmittedSignal = modulatedData;
61
+ end
62
+
63
+
64
+ end
CMuSeNet_Synthetic_IQ_Generator/datagenWideband.m ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function [metadata, widebandSignal] = datagenWideband(SNRdB, fadingType)
2
+ % Constant for this function
3
+ RolloffFactor = 0.35;
4
+ RaisedCosineFilterSpan = 10;
5
+ Interpolation = 2;
6
+ NarrowBandBWs = [1e5, 2e5, 5e5, 1e6, 2e6];
7
+ WideBandBW = 20e6;
8
+ MaxSignals = 10;
9
+ % Modulations = ["QPSK" "BPSK" "8-PSK" "8-QAM" "16-QAM" "2-FSK" ];
10
+ Modulations = ["QPSK" "BPSK" "8-PSK" "8-QAM" "16-QAM", "GMSK", "2-FSK"];
11
+ SamplingTime = 2/1000; % 2ms
12
+
13
+ TxPowerRange = [0, 20];
14
+
15
+ numberOfSignals = randi([1, MaxSignals], 1);
16
+
17
+ signalBW = randsample(NarrowBandBWs, numberOfSignals, true);
18
+ txPowers = randi(TxPowerRange, [numberOfSignals, 1]);
19
+
20
+
21
+ minGap = 100e3; % 100kHz
22
+
23
+ maxBW = max(signalBW);
24
+
25
+ % Allocate a space for the frequencies
26
+ freqOffsets = [];
27
+ usedFreqs = [];
28
+ % A mechanism to prevent it from being stuck if there are too many
29
+ % wideband signals
30
+ maxLoops = numberOfSignals * 10;
31
+ % Generate non-overlapping frequencies
32
+ for i = 1:numberOfSignals
33
+ bw = signalBW(i);
34
+ % Generate a random frequency offset within the limits
35
+ while maxLoops > 0
36
+ maxLoops = maxLoops - 1; % prevent it from handing
37
+ freq = randi([-WideBandBW/2 + bw/2, WideBandBW/2 - bw/2]);
38
+ % Check if the frequency space for the new signal is already occupied or
39
+ % if the new signal is within minGap of an existing signal
40
+ overlap = false;
41
+ for j = 1:length(usedFreqs)
42
+ existing_bw = signalBW(j);
43
+ if abs(freq - usedFreqs(j)) < (bw + existing_bw)/2 + minGap
44
+ overlap = true;
45
+ break;
46
+ end
47
+ end
48
+ if ~overlap
49
+ % If not, add the frequency to the used frequencies and break the loop
50
+ usedFreqs = [usedFreqs freq];
51
+ freqOffsets = [freqOffsets freq];
52
+ break
53
+ end
54
+ % If the frequency space is occupied or too close to another signal,
55
+ % generate a new random frequency
56
+ end
57
+
58
+ if maxLoops <= 0
59
+ numberOfSignals = length(freqOffsets);
60
+ disp("Stopping because couldn't place signal");
61
+ disp(signalBW);
62
+ break;
63
+ end
64
+
65
+ end
66
+
67
+
68
+ signals = [];
69
+ metadata = [];
70
+
71
+ lowestPowerSignal = min(txPowers);
72
+ noisePower = min(txPowers) - SNRdB;
73
+
74
+
75
+ for i = 1: numberOfSignals
76
+ modulation = randsample(Modulations, 1);
77
+ txPower = txPowers(i);
78
+ bw = signalBW(i);
79
+ % Should the divisor be 20 ?
80
+ signal = datagenTransmitter( ...
81
+ modulation, ...
82
+ RolloffFactor, ...
83
+ RaisedCosineFilterSpan, ...
84
+ Interpolation, ...
85
+ bw, ...
86
+ SamplingTime...811
87
+ );
88
+
89
+ % Scale the signal
90
+ signal = signal/sqrt(mean(abs(signal).^2));
91
+
92
+ % Scale to correct power
93
+ signal = 10^(txPower/20)*signal;
94
+
95
+ pwr = 10*log10(mean(abs(signal).^2));
96
+
97
+
98
+ if bw ~= maxBW
99
+ signal = resample(signal, maxBW/1e5, bw/1e5);
100
+ end
101
+ signals = [signals signal];
102
+ metadata = [metadata; struct("fc", freqOffsets(i), "bw", bw, "mod", modulation, "txPower", txPower, "noisePower", noisePower)];
103
+
104
+ end
105
+ mbc = comm.MultibandCombiner( ...
106
+ InputSampleRate=maxBW, ...
107
+ FrequencyOffsets=freqOffsets, ...
108
+ OutputSampleRateSource="property", ...
109
+ OutputSampleRate=WideBandBW ...
110
+ );
111
+
112
+ combinedsig = mbc(signals);
113
+ % Channel configuration
114
+ fd = 30; % Max Doppler shift in Hz
115
+ Ts = 1/WideBandBW; % Sampling time
116
+ chan = [];
117
+
118
+ switch lower(fadingType)
119
+ case 'awgn'
120
+ % Just noise without fading
121
+ widebandSignal = awgn(combinedsig, SNRdB, lowestPowerSignal);
122
+
123
+ case 'rayleigh'
124
+ rayleighChan = comm.RayleighChannel( ...
125
+ 'SampleRate', WideBandBW, ...
126
+ 'PathDelays', 0, ...
127
+ 'AveragePathGains', 0, ...
128
+ 'MaximumDopplerShift', 30 ...
129
+ );
130
+ fadedSignal = rayleighChan(combinedsig);
131
+ widebandSignal = awgn(fadedSignal, SNRdB, lowestPowerSignal); % Add AWGN
132
+
133
+ case 'rician'
134
+ ricianChan = comm.RicianChannel( ...
135
+ 'SampleRate', WideBandBW, ...
136
+ 'PathDelays', 0, ...
137
+ 'AveragePathGains', 0, ...
138
+ 'KFactor', 10, ...
139
+ 'MaximumDopplerShift', 30 ...
140
+ );
141
+ fadedSignal = ricianChan(combinedsig);
142
+ widebandSignal = awgn(fadedSignal, SNRdB, lowestPowerSignal); % Add AWGN
143
+
144
+ otherwise
145
+ error('Unsupported fading type: %s', fadingType);
146
+ end
147
+ end