{ "cells": [ { "cell_type": "markdown", "id": "b5007b71", "metadata": {}, "source": [ "### Initialization" ] }, { "cell_type": "code", "execution_count": null, "id": "3e6b1226", "metadata": {}, "outputs": [], "source": [ "### Initialization block\n", "from pathlib import Path\n", "import numpy as np\n", "import json\n", "import torch\n", "import numpy as np\n", "from tqdm import tqdm\n", "import math\n", "from torch.utils.data import DataLoader, TensorDataset\n", "\n", "STFT_LENGTH = 16 * 1024\n", "DATA_DIR = Path(\"dataset/\")\n", "SAMPLE_RATE = 20e6\n", "MODULATIONS = [\"QPSK\", \"BPSK\", \"8-PSK\", \"8-QAM\", \"16-QAM\", \"GMSK\", \"2-FSK\"]\n", "MODULATION_LABELS = {j: i for i, j in enumerate(MODULATIONS)}\n", "NUMBER_OF_MODULATIONS = len(MODULATIONS)\n", "\n", "def load_data(snr, name, load_metadata_only=False):\n", " if not load_metadata_only:\n", " with open(DATA_DIR/str(snr)/str(name)/\"data.dat\", \"rb\") as f:\n", " signal = np.fromfile(f, dtype=np.complex128)\n", " else:\n", " signal = None\n", " with open(DATA_DIR/str(snr)/str(name)/\"meta-data.json\") as f:\n", " meta = json.load(f)\n", " if type(meta) == dict:\n", " meta = [meta]\n", " return signal, meta\n", "\n", " \n", "def _get_all_numbered_dirs(root_dir):\n", " dirs = []\n", " for directory in root_dir.iterdir():\n", " dirs.append(int(directory.name))\n", " dirs.sort()\n", " return dirs\n", "\n", "def get_signals(snr):\n", " return _get_all_numbered_dirs(Path(DATA_DIR)/str(snr))\n", "\n", "\n", "def get_snrs(root_dir=DATA_DIR):\n", " return _get_all_numbered_dirs(root_dir)\n", " \n", " \n", "def process_metadata(metadata):\n", " scaled_metadata = [\n", " {\n", " \"position\": (SAMPLE_RATE/2 + i['fc'], i['bw']),\n", " \"mod\": i[\"mod\"]\n", " }\n", " for i in metadata\n", " ]\n", " return scaled_metadata\n", "\n", "\n", "def process_signal(signal):\n", " signal = signal[:STFT_LENGTH]\n", "\n", " signal = np.fft.fft(signal)\n", " signal = np.fft.fftshift(signal)\n", " signal /= np.max(np.abs(signal))\n", " \n", " #return np.expand_dims(signal, axis=0)\n", " return signal" ] }, { "cell_type": "markdown", "id": "440b802c", "metadata": {}, "source": [ "### Data Loading" ] }, { "cell_type": "code", "execution_count": null, "id": "31bc3770", "metadata": {}, "outputs": [], "source": [ "MASK_SIZE = int(STFT_LENGTH)\n", "\n", "class WidebandSignalDataset(torch.utils.data.Dataset):\n", " def __init__(self, signal_ids, mask_size=MASK_SIZE, return_snr=False):\n", " self.mask_size = mask_size\n", " self.signal_ids = signal_ids\n", " self.return_snr = return_snr # New parameter to control SNR return\n", " loaded_data = []\n", " for snr, signal_id in tqdm(self.signal_ids):\n", " signal, masks = self.process_signal(snr, signal_id)\n", " loaded_data.append((signal, masks))\n", " self.loaded_data = loaded_data\n", "\n", " def __len__(self):\n", " return len(self.signal_ids)\n", "\n", " def __getitem__(self, index):\n", " signal, masks = self.loaded_data[index]\n", " if self.return_snr:\n", " snr, _ = self.signal_ids[index]\n", " return signal, masks, snr # Return SNR during evaluation\n", " else:\n", " return signal, masks # Return only signal and masks during training\n", "\n", " def process_signal(self, snr, signal_id):\n", " signal, metadata = load_data(snr, signal_id)\n", " scaled_metadata = process_metadata(metadata)\n", " signal = process_signal(signal)\n", " signal = torch.from_numpy(signal)\n", " masks = torch.zeros(self.mask_size)\n", " scale_ratio = self.mask_size / SAMPLE_RATE\n", " for meta in scaled_metadata:\n", " f, b = meta['position']\n", " x1, x2 = math.floor((f - b / 2) * scale_ratio), math.ceil((f + b / 2) * scale_ratio)\n", " masks[x1:x2] = 1\n", " return signal.type(torch.complex64), masks.type(torch.FloatTensor)\n", "\n", "# Train test split 80 - 10 - 10\n", "train, test, validation = [], [], [] \n", "for snr in get_snrs():\n", " signals = get_signals(snr)\n", " total_signals = len(signals)\n", " for signal in signals:\n", " if signal <= 0.8 * total_signals:\n", " train.append((snr, signal))\n", " elif signal <= 0.9 * total_signals:\n", " validation.append((snr, signal))\n", " else:\n", " test.append((snr, signal))\n", " \n", "print(\"Train\", len(train))\n", "print(\"Validation\", len(validation))\n", "print(\"Test\", len(test))\n", "\n", "train_dataset = WidebandSignalDataset(signal_ids=train)\n", "validation_dataset = WidebandSignalDataset(signal_ids=validation)\n", "test_dataset = WidebandSignalDataset(signal_ids=test)" ] }, { "cell_type": "markdown", "id": "637ae774", "metadata": {}, "source": [ "### Batch Loading" ] }, { "cell_type": "code", "execution_count": null, "id": "a9af2450", "metadata": {}, "outputs": [], "source": [ "batch_size = 64 # Updated batch size\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", "valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)\n", "\n", "print(\"Train labels shape:\", len(train_dataset))\n", "print(\"Validation labels shape:\", len(validation_dataset))" ] }, { "cell_type": "markdown", "id": "9a8e09e4", "metadata": {}, "source": [ "### Early Stop" ] }, { "cell_type": "code", "execution_count": null, "id": "24f79a24", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "class EarlyStopping:\n", " def __init__(self, patience=10, verbose=False, delta=0.0001, save_path='./models/CMuSeNet'):\n", " self.patience = patience\n", " self.verbose = verbose\n", " self.delta = delta\n", " self.counter = 0\n", " self.best_score = None\n", " self.early_stop = False\n", " self.val_loss_min = float('inf')\n", " self.best_model = None\n", " self.save_path = save_path\n", " os.makedirs(save_path, exist_ok=True)\n", " \n", " def __call__(self, val_loss, model):\n", " score = -val_loss\n", "\n", " if self.best_score is None:\n", " self.best_score = score\n", " self.save_checkpoint(val_loss, model)\n", " elif score < self.best_score + self.delta:\n", " self.counter += 1\n", " if self.verbose:\n", " print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n", " if self.counter >= self.patience:\n", " self.early_stop = True\n", " else:\n", " self.best_score = score\n", " self.save_checkpoint(val_loss, model)\n", " self.counter = 0\n", "\n", " def save_checkpoint(self, val_loss, model):\n", " if self.verbose:\n", " print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n", " self.val_loss_min = val_loss\n", " self.best_model = model.state_dict()\n", " save_path = os.path.join(self.save_path, 'best_model.pth')\n", " torch.save(self.best_model, save_path)" ] }, { "cell_type": "markdown", "id": "6c3fda74", "metadata": {}, "source": [ "### Reshape" ] }, { "cell_type": "code", "execution_count": null, "id": "5fcf91db", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import complexPyTorch.complexLayers as cplx\n", "import torch.nn.functional as F\n", "import torch\n", "\n", "def reshape_to_2d(data):\n", " return data.view(-1, 1, 128, 128) # Reshape to [batch, channels, height, width]" ] }, { "cell_type": "markdown", "id": "b7d7562c", "metadata": {}, "source": [ "### Complex IoU" ] }, { "cell_type": "code", "execution_count": null, "id": "7218c3f3", "metadata": {}, "outputs": [], "source": [ "def calculate_iou(pred, target, threshold=0.5):\n", " real_pred = (pred.real > threshold).float()\n", " imag_pred = (pred.imag > threshold).float()\n", " \n", " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n", " \n", " intersection = (combined_pred * target).sum(dim=1)\n", " union = (combined_pred + target).sum(dim=1) - intersection\n", " iou = (intersection / union).mean().item()\n", " return iou" ] }, { "cell_type": "markdown", "id": "64f4063c", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": null, "id": "66825110", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "def validate_model(model, valid_loader, criterion):\n", " model.eval()\n", " running_loss = 0.0\n", " iou_scores = []\n", " total_correct = 0\n", " total_samples = 0\n", "\n", " with torch.no_grad():\n", " for inputs, masks in tqdm(valid_loader, desc=\"Validating\"):\n", " inputs = reshape_to_2d(inputs).to(device)\n", " masks = masks.to(device)\n", " outputs = model(inputs)\n", " loss = criterion(outputs, masks)\n", " running_loss += loss.item()\n", "\n", " # Calculate IoU\n", " iou = calculate_iou(outputs, masks, threshold=0.5)\n", " iou_scores.append(iou)\n", " \n", " # Calculate accuracy\n", " preds = ((outputs.real > 0.5) & (outputs.imag > 0.5)).float()\n", " correct = (preds == masks).float().sum()\n", " total_correct += correct.item()\n", " total_samples += masks.numel()\n", "\n", " val_loss = running_loss / len(valid_loader)\n", " mean_iou = sum(iou_scores) / len(iou_scores)\n", " accuracy = total_correct / total_samples * 100\n", "\n", " print(f'Validation Loss: {val_loss:.6f}')\n", " print(f'Validation Accuracy: {accuracy:.2f}%')\n", "\n", " return val_loss, accuracy\n", "\n", "def train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.00001], num_epochs=50, patience=3):\n", " train_losses = []\n", " val_losses = []\n", " val_accuracies = []\n", " epoch_durations = []\n", " \n", " current_lr = initial_lr\n", " for lr in lr_steps:\n", " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", " early_stopping = EarlyStopping(patience=patience, verbose=True, delta=0.001)\n", " print(\"Current learning rate: \", lr)\n", " for epoch in range(num_epochs):\n", " epoch_start_time = time.time()\n", " \n", " model.train()\n", " running_loss = 0.0\n", " for inputs, masks in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs} - Training\"):\n", " inputs = reshape_to_2d(inputs).to(device)\n", " masks = masks.to(device)\n", " outputs = model(inputs)\n", " loss = criterion(outputs, masks)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " running_loss += loss.item()\n", "\n", " epoch_loss = running_loss / len(train_loader)\n", " train_losses.append(epoch_loss)\n", " print(f\"Training Loss: {epoch_loss:.6f}\")\n", "\n", " val_loss, val_accuracy = validate_model(model, valid_loader, criterion)\n", " val_losses.append(val_loss)\n", " val_accuracies.append(val_accuracy)\n", " early_stopping(val_loss, model)\n", "\n", " if early_stopping.early_stop:\n", " print(\"Early stopping triggered\")\n", " break\n", "\n", " epoch_duration = time.time() - epoch_start_time\n", " epoch_durations.append(epoch_duration)\n", " if early_stopping.best_model is not None:\n", " print(f\"Loading best model from lr {lr}\")\n", " model.load_state_dict(early_stopping.best_model)\n", " \n", " print(\"Training completed.\")\n", " print(\"Epoch durations:\", epoch_durations)\n", " return model, train_losses, val_losses, val_accuracies, epoch_durations" ] }, { "cell_type": "markdown", "id": "0b80cb51", "metadata": {}, "source": [ "### ResNet-18" ] }, { "cell_type": "code", "execution_count": null, "id": "2d208cb9", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import complexPyTorch.complexLayers as cplx\n", "from typing import Optional, Callable, Type, Union, List\n", "import torch.nn.functional as F\n", "from torch import Tensor\n", "\n", "def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", "\n", "def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> cplx.ComplexConv2d:\n", " \"\"\"1x1 convolution\"\"\"\n", " return cplx.ComplexConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n", "\n", "class BasicBlock(nn.Module):\n", " expansion = 1\n", "\n", " def __init__(\n", " self,\n", " inplanes: int,\n", " planes: int,\n", " stride: int = 1,\n", " downsample: Optional[nn.Module] = None,\n", " norm_layer: Optional[Callable[..., nn.Module]] = None,\n", " ) -> None:\n", " super(BasicBlock, self).__init__()\n", " self.conv1 = conv3x3(inplanes, planes, stride)\n", " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n", " self.relu = cplx.ComplexReLU()\n", " self.conv2 = conv3x3(planes, planes)\n", " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out\n", "\n", "class Bottleneck(nn.Module):\n", " expansion = 4\n", "\n", " def __init__(\n", " self,\n", " inplanes: int,\n", " planes: int,\n", " stride: int = 1,\n", " downsample: Optional[nn.Module] = None,\n", " norm_layer: Optional[Callable[..., nn.Module]] = None,\n", " ) -> None:\n", " super(Bottleneck, self).__init__()\n", " self.conv1 = conv1x1(inplanes, planes)\n", " self.bn1 = cplx.ComplexBatchNorm2d(planes)\n", " self.conv2 = conv3x3(planes, planes, stride)\n", " self.bn2 = cplx.ComplexBatchNorm2d(planes)\n", " self.conv3 = conv1x1(planes, planes * self.expansion)\n", " self.bn3 = cplx.ComplexBatchNorm2d(planes * self.expansion)\n", " self.relu = cplx.ComplexReLU()\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out\n", "\n", "class ComplexResNet(nn.Module):\n", " def __init__(\n", " self,\n", " block: Type[Union[BasicBlock, Bottleneck]],\n", " layers: List[int],\n", " num_classes: int = STFT_LENGTH,\n", " zero_init_residual: bool = False,\n", " groups: int = 1,\n", " width_per_group: int = 64,\n", " norm_layer: Optional[Callable[..., nn.Module]] = None,\n", " ) -> None:\n", " super(ComplexResNet, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = cplx.ComplexBatchNorm2d\n", " self._norm_layer = norm_layer\n", "\n", " self.inplanes = 64\n", " self.dilation = 1\n", "\n", " self.groups = groups\n", " self.base_width = width_per_group\n", " self.conv1 = cplx.ComplexConv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)\n", " self.bn1 = norm_layer(self.inplanes)\n", " self.relu = cplx.ComplexReLU()\n", " self.layer1 = self._make_layer(block, 64, layers[0])\n", " self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n", " self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n", " self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n", " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", " self.fc = cplx.ComplexLinear(512 * block.expansion, num_classes)\n", " self.sigmoid = cplx.ComplexSigmoid()\n", "\n", " def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n", " norm_layer = self._norm_layer\n", " downsample = None\n", " if stride != 1 or self.inplanes != planes * block.expansion:\n", " downsample = nn.Sequential(\n", " conv1x1(self.inplanes, planes * block.expansion, stride),\n", " norm_layer(planes * block.expansion),\n", " )\n", "\n", " layers = []\n", " layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n", " self.inplanes = planes * block.expansion\n", " for _ in range(1, blocks):\n", " layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n", "\n", " return nn.Sequential(*layers)\n", "\n", " def _forward_impl(self, x: Tensor) -> Tensor:\n", " x = self.conv1(x)\n", " x = self.bn1(x)\n", " x = self.relu(x)\n", "\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.layer3(x)\n", " x = self.layer4(x)\n", "\n", " x = self.avgpool(x)\n", " x = torch.flatten(x, 1)\n", " x = self.fc(x)\n", " x = self.sigmoid(x)\n", " return x\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " return self._forward_impl(x)\n", "\n", "def ComplexResNet18():\n", " return ComplexResNet(BasicBlock, [2, 2, 2, 2])\n", "\n", "# Create the model instance\n", "model = ComplexResNet18()\n", "print(model)\n" ] }, { "cell_type": "markdown", "id": "e4bc1b5d", "metadata": {}, "source": [ "### Complex focal Loss" ] }, { "cell_type": "code", "execution_count": null, "id": "61c29429", "metadata": {}, "outputs": [], "source": [ "class ComplexFocalLoss(nn.Module):\n", " def __init__(self, alpha=1, gamma=2, reduction='mean'):\n", " super(ComplexFocalLoss, self).__init__()\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " self.reduction = reduction\n", "\n", " def forward(self, inputs, targets):\n", " real_inputs = inputs.real\n", " imag_inputs = inputs.imag\n", " \n", " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction='none')\n", " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction='none')\n", " \n", " real_pt = torch.exp(-real_BCE_loss)\n", " imag_pt = torch.exp(-imag_BCE_loss)\n", " \n", " real_F_loss = self.alpha * (1 - real_pt) ** self.gamma * real_BCE_loss\n", " imag_F_loss = self.alpha * (1 - imag_pt) ** self.gamma * imag_BCE_loss\n", "\n", " if self.reduction == 'mean':\n", " return (torch.mean(real_F_loss) + torch.mean(imag_F_loss)) / 2\n", " elif self.reduction == 'sum':\n", " return torch.sum(real_F_loss) + torch.sum(imag_F_loss)\n", " else:\n", " return real_F_loss + imag_F_loss\n", "\n", "# Update the IoU calculation to handle complex values\n", "def calculate_iou(pred, target, threshold=0.5):\n", " real_pred = (pred.real > threshold).float()\n", " imag_pred = (pred.imag > threshold).float()\n", " \n", " combined_pred = torch.logical_or(real_pred, imag_pred).float()\n", " \n", " intersection = (combined_pred * target).sum(dim=1)\n", " union = (combined_pred + target).sum(dim=1) - intersection\n", " iou = (intersection / union).mean().item()\n", " return iou" ] }, { "cell_type": "markdown", "id": "abb35ba2", "metadata": {}, "source": [ "### Training with complex focal loss" ] }, { "cell_type": "code", "execution_count": null, "id": "86d7526b", "metadata": {}, "outputs": [], "source": [ "# Initialize and train the CResNet-18 model\n", "model = ComplexResNet18().to(device)\n", "criterion = ComplexFocalLoss()\n", "\n", "# Train the model and validate it\n", "#0.001, 0.0001, 0.00001, 0.000001\n", "model, train_losses, val_losses, val_accuracies, epoch_durations =train_model(model, train_loader, valid_loader, criterion, initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3)\n", "combined_epoch_time = sum(epoch_durations)\n", "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")" ] }, { "cell_type": "markdown", "id": "fd0c9d58", "metadata": {}, "source": [ "### CVNN RV-BCE and CV-BCE Loss function implementation" ] }, { "cell_type": "code", "execution_count": null, "id": "99c736b8", "metadata": {}, "outputs": [], "source": [ "# RV BCE Loss Function Definition\n", "class RealValuedBCELoss(nn.Module):\n", " def __init__(self, reduction='mean'):\n", " super(RealValuedBCELoss, self).__init__()\n", " self.reduction = reduction\n", "\n", " def forward(self, inputs, targets):\n", " # Use only the real part of the complex inputs\n", " real_inputs = inputs.real\n", " BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n", " return BCE_loss\n", "\n", " \n", "# CV BCE Loss Function Definition\n", "class ComplexValuedBCELoss(nn.Module):\n", " def __init__(self, reduction='mean'):\n", " super(ComplexValuedBCELoss, self).__init__()\n", " self.reduction = reduction\n", "\n", " def forward(self, inputs, targets):\n", " real_inputs = inputs.real\n", " imag_inputs = inputs.imag\n", "\n", " # Calculate binary cross-entropy for both real and imaginary parts\n", " real_BCE_loss = F.binary_cross_entropy(real_inputs, targets, reduction=self.reduction)\n", " imag_BCE_loss = F.binary_cross_entropy(imag_inputs, targets, reduction=self.reduction)\n", " \n", " # Combine the losses (you can adjust the weighting if necessary)\n", " combined_BCE_loss = (real_BCE_loss + imag_BCE_loss) / 2\n", " return combined_BCE_loss" ] }, { "cell_type": "markdown", "id": "d6930f39", "metadata": {}, "source": [ "### RV-BCE Training" ] }, { "cell_type": "code", "execution_count": null, "id": "9e59d4c9", "metadata": {}, "outputs": [], "source": [ "# Set the criterion for RV BCE\n", "criterion = RealValuedBCELoss()\n", "\n", "# Train the ResNet-18 model with RV BCE\n", "device = torch.device('cuda')\n", "model = ComplexResNet18().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "\n", "# Start training with the previously defined train_model function\n", "model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n", " model, train_loader, valid_loader, criterion, \n", " initial_lr=0.001, lr_steps=[0.001, 0.0001, 0.00001, 0.000001], num_epochs=50, patience=3\n", ")\n" ] }, { "cell_type": "markdown", "id": "93d19ea7", "metadata": {}, "source": [ "### CV-BCE Training" ] }, { "cell_type": "code", "execution_count": null, "id": "2c56d5b4", "metadata": {}, "outputs": [], "source": [ "# Set the criterion for CV BCE\n", "criterion = ComplexValuedBCELoss()\n", "\n", "# Train the ResNet-18 model with CV BCE\n", "device = torch.device('cuda')\n", "model = ComplexResNet18().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "\n", "# Start training with the previously defined train_model function\n", "model, train_losses, val_losses, val_accuracies, epoch_durations = train_model(\n", " model, train_loader, valid_loader, criterion, \n", " initial_lr=0.001, lr_steps=[0.001, 0.0001], num_epochs=50, patience=3\n", ")\n", "combined_epoch_time = sum(epoch_durations)\n", "print(f\"Total time spent in epochs: {combined_epoch_time:.2f} seconds.\")" ] }, { "cell_type": "markdown", "id": "f4f6530e", "metadata": {}, "source": [ "### Plot training result (Accuracy, loss vs epoch)" ] }, { "cell_type": "code", "execution_count": null, "id": "43676a01", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import json\n", "import os\n", "\n", "# Ensure the directory exists\n", "output_dir = 'cvnn_results/segmentation'\n", "os.makedirs(output_dir, exist_ok=True)\n", "\n", "def save_metrics_to_json(train_losses, val_accuracies, epoch_durations, filename):\n", " \"\"\"\n", " Save the training losses and validation accuracies to a JSON file.\n", " \n", " Args:\n", " train_losses (list): List of training losses.\n", " val_accuracies (list): List of validation accuracies.\n", " filename (str): The file name for the JSON file.\n", " \"\"\"\n", " metrics = {\n", " \"train_losses\": train_losses,\n", " \"val_accuracies\": val_accuracies,\n", " \"epoch_durations\": epoch_durations\n", " }\n", " with open(os.path.join(output_dir, filename), 'w') as f:\n", " json.dump(metrics, f)\n", "\n", "def plot_training_metrics(train_losses, val_accuracies, plot_filename):\n", " \"\"\"\n", " Plot the training loss and validation accuracy, and mark the epoch where accuracy reaches 99%.\n", " \n", " Args:\n", " train_losses (list): List of training losses.\n", " val_accuracies (list): List of validation accuracies.\n", " plot_filename (str): The file name for saving the plot as SVG.\n", " \"\"\"\n", " epochs = range(1, len(train_losses) + 1)\n", "\n", " plt.figure(figsize=(14, 6))\n", "\n", " # Plot Training Loss\n", " plt.subplot(1, 2, 1)\n", " plt.plot(epochs, train_losses, label='Training Loss')\n", " plt.xlabel('Epochs')\n", " plt.ylabel('Loss')\n", " plt.title('Training Loss')\n", " plt.legend()\n", "\n", " # Plot Validation Accuracy\n", " plt.subplot(1, 2, 2)\n", " plt.plot(epochs, val_accuracies, label='Validation Accuracy')\n", " plt.xlabel('Epochs')\n", " plt.ylabel('Accuracy (%)')\n", " plt.title('Validation Accuracy')\n", " plt.legend()\n", "\n", " # Find the first epoch where validation accuracy reaches or exceeds 99%\n", " for i, acc in enumerate(val_accuracies):\n", " if acc >= 99:\n", " first_99_epoch = i + 1 # Epochs are 1-based\n", " plt.axvline(first_99_epoch, color='r', linestyle='--', label=f'99% reached at epoch {first_99_epoch}')\n", " break\n", "\n", " plt.legend()\n", " plt.tight_layout()\n", "\n", " # Save the plot as an SVG file\n", " plt.savefig(os.path.join(output_dir, plot_filename), format='svg')\n", " plt.show()\n", "\n", "# Save the metrics to JSON in cvnn_results/segmentation\n", "save_metrics_to_json(train_losses, val_accuracies, epoch_durations, 'training_metrics.json')\n", "\n", "# Plot the metrics and highlight when accuracy reaches 99%, saving the plot as SVG\n", "plot_training_metrics(train_losses, val_accuracies, 'training_metrics_plot.svg')" ] }, { "cell_type": "markdown", "id": "c6f4ea75", "metadata": {}, "source": [ "### Evaluation " ] }, { "cell_type": "code", "execution_count": null, "id": "a303080e", "metadata": {}, "outputs": [], "source": [ "# Load the pre-trained model for evaluation\n", "import torch\n", "\n", "device = \"cuda\"\n", "\n", "model_path = \"path/to/the/model\" #Please change this to the model path you trained\n", "model = ComplexResNet18().to(device)\n", "model.load_state_dict(torch.load(model_path, map_location=device))\n", "model.eval()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0590b6ef", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from tqdm import tqdm\n", "from torch.utils.data import DataLoader\n", "import numpy as np\n", "\n", "# Define thresholds for recall calculation\n", "iou_thresholds = [0.5, 0.7, 0.9]\n", "\n", "# Initialize metrics\n", "snr_results = {}\n", "total_accuracy = 0.0\n", "total_samples = 0\n", "iou_scores = {th: 0.0 for th in iou_thresholds}\n", "recall_counts = {th: 0 for th in iou_thresholds}\n", "BATCH_SIZE = 64\n", "# Create DataLoader for the entire dataset\n", "full_dataset = WidebandSignalDataset(signal_ids=train + validation + test, return_snr=True)\n", "full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)" ] }, { "cell_type": "markdown", "id": "6db6a18f", "metadata": {}, "source": [ "### Bounding Box" ] }, { "cell_type": "code", "execution_count": null, "id": "e396c72c", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from collections import defaultdict\n", "import time\n", "from tqdm import tqdm\n", "import torch\n", "import torch.nn.functional as F\n", "from scipy.optimize import linear_sum_assignment\n", "\n", "def expand_true(array, distance=1):\n", " # Create kernel of appropriate size\n", " kernel = torch.ones((1, 1, distance * 2 + 1), device=array.device)\n", " array = array.unsqueeze(1).float() # Add channel dimension\n", " result = F.conv1d(array, kernel, padding=distance)\n", " result = result.squeeze(1) # Remove the extra dimension\n", " \n", " # Convert values greater than 0 to `True`\n", " return result > 0\n", "\n", "# Define supporting functions based on your friend's code\n", "def get_true_groups(tensor, device):\n", " assert tensor.dim() == 2, 'This function handles 2D tensor only'\n", " all_groups = []\n", " for i in range(tensor.size(0)):\n", " item = tensor[i]\n", " item = torch.cat([torch.tensor([False]).to(device), item, torch.tensor([False]).to(device)])\n", " diffs = item.float().diff()\n", " starts = (diffs == 1).nonzero(as_tuple=True)[0]\n", " ends = (diffs == -1).nonzero(as_tuple=True)[0] - 1\n", " groups = [(start.item(), end.item()) for start, end in zip(starts, ends)]\n", " all_groups.append(groups)\n", " return all_groups\n", "\n", "def get_target_boxes(metadata, number_of_bins, sample_rate=SAMPLE_RATE):\n", " scale_ratio = number_of_bins / sample_rate\n", " targets = []\n", " masks = torch.zeros(number_of_bins)\n", " for meta in metadata:\n", " f, b = meta['position']\n", " x1, x2 = math.floor((f-b/2)*scale_ratio), math.ceil((f+b/2)*scale_ratio)\n", " masks[x1:x2] = 1\n", " targets.append((x1, x2))\n", " return targets, masks\n", "\n", "def get_target_boxes_batch(batch_metadata, number_of_bins, sample_rate=SAMPLE_RATE):\n", " all_targets, all_masks = [], []\n", " for metadata in batch_metadata:\n", " targets, masks = get_target_boxes(metadata, number_of_bins, sample_rate)\n", " all_targets.append(targets)\n", " all_masks.append(masks)\n", " return all_targets, all_masks\n", "\n", "def calculate_iou(box1, box2):\n", " intersection = max(0, min(box1[1], box2[1]) - max(box1[0], box2[0]))\n", " union = max(box1[1], box2[1]) - min(box1[0], box2[0])\n", " return intersection / union if union != 0 else 0\n", "\n", "def match_targets(targets, preds):\n", " ious = []\n", " for target in targets:\n", " iou_targets = []\n", " for pred in preds:\n", " iou_targets.append(calculate_iou(target, pred))\n", " ious.append(iou_targets)\n", " return linear_sum_assignment(ious, maximize=True)\n", "\n", "def match_targets_batch(batch_targets, batch_preds):\n", " all_assignments = []\n", " for targets, preds in zip(batch_targets, batch_preds):\n", " all_assignments.append(match_targets(targets, preds))\n", " return all_assignments\n", "\n", "def calculate_matched_ious(target_boxes, prediction_boxes, matching):\n", " ious = [0 for _ in target_boxes]\n", " matching_dict = dict(zip(*matching))\n", " for target_index, target_box in enumerate(target_boxes):\n", " if target_index in matching_dict:\n", " box1 = target_box\n", " box2 = prediction_boxes[matching_dict[target_index]]\n", " ious[target_index] = calculate_iou(box1, box2)\n", " return ious\n", "\n", "def calculate_matched_iou_mean_batch(batch_target_boxes, batch_pred_boxes, batch_matching):\n", " all_ious = []\n", " for args in zip(batch_target_boxes, batch_pred_boxes, batch_matching):\n", " all_ious.append(calculate_matched_ious(*args))\n", " return all_ious\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "24d483c1", "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "from tqdm import tqdm\n", "def model_predictor(signals):\n", " # Use the already loaded model and apply thresholding\n", " signals = reshape_to_2d(signals)\n", " outputs = model(signals)\n", " return expand_true(outputs.real > 0.5) # Use real part for thresholding\n", "def evaluate(predictor, data_loader, device=\"cuda\"):\n", " snr_metrics = defaultdict(lambda: {\n", " \"iou_sum\": 0.0,\n", " \"iou_count\": 0,\n", " \"recall_counts\": defaultdict(int),\n", " \"total_samples\": defaultdict(int),\n", " \"correct_pixels\": 0,\n", " \"total_pixels\": 0\n", " })\n", " total_iou_sum, total_iou_count = 0.0, 0\n", " total_correct_pixels, total_total_pixels = 0, 0\n", " total_recall_counts = defaultdict(int)\n", " total_samples = defaultdict(int)\n", "\n", " for inputs, masks, snrs_in_batch in tqdm(data_loader, desc=\"Evaluating\"):\n", " #inputs = inputs.to(device)\n", " inputs = reshape_to_2d(inputs).to(device)\n", " masks = masks.to(device)\n", " outputs = predictor(inputs)\n", "\n", " for i in range(len(snrs_in_batch)):\n", " snr = snrs_in_batch[i].item()\n", " mask = masks[i]\n", " output = outputs[i]\n", "\n", " # Ensure output matches mask shape\n", " if output.numel() != mask.numel():\n", " output = output.expand_as(mask) if output.numel() == 1 else output.reshape_as(mask)\n", "\n", " thresholded_output = (output.real >= 0.5).float()\n", "\n", " correct_pixels = (thresholded_output == mask).sum().item()\n", " total_pixels = mask.numel()\n", " snr_metrics[snr][\"correct_pixels\"] += correct_pixels\n", " snr_metrics[snr][\"total_pixels\"] += total_pixels\n", " total_correct_pixels += correct_pixels\n", " total_total_pixels += total_pixels\n", "\n", " target_boxes = get_true_groups(mask.unsqueeze(0), device=device)[0]\n", " pred_boxes = get_true_groups(thresholded_output.unsqueeze(0), device=device)[0]\n", " if not target_boxes or not pred_boxes:\n", " continue\n", " matching = match_targets(target_boxes, pred_boxes)\n", " matched_ious = calculate_matched_ious(target_boxes, pred_boxes, matching)\n", "\n", " snr_metrics[snr][\"iou_sum\"] += sum(matched_ious)\n", " snr_metrics[snr][\"iou_count\"] += len(matched_ious)\n", " total_iou_sum += sum(matched_ious)\n", " total_iou_count += len(matched_ious)\n", "\n", " for th in iou_thresholds:\n", " true_positives = sum(1 for iou in matched_ious if iou >= th)\n", " snr_metrics[snr][\"recall_counts\"][th] += true_positives\n", " snr_metrics[snr][\"total_samples\"][th] += len(target_boxes)\n", " total_recall_counts[th] += true_positives\n", " total_samples[th] += len(target_boxes)\n", "\n", " # Calculate overall metrics\n", " overall_accuracy = (total_correct_pixels / total_total_pixels) * 100 if total_total_pixels > 0 else 0\n", " overall_iou = total_iou_sum / total_iou_count if total_iou_count > 0 else 0\n", " overall_recall = {th: total_recall_counts[th] / total_samples[th] if total_samples[th] > 0 else 0 for th in iou_thresholds}\n", "\n", " # Print overall results\n", " print(f\"Overall Accuracy: {overall_accuracy:.2f}%\")\n", " print(f\"Overall IoU Score: {overall_iou:.4f}\")\n", " for th in iou_thresholds:\n", " print(f\"Recall at threshold {th}: {overall_recall[th]:.4f}\")\n", "\n", " # Print per-SNR results\n", " for snr, metrics in sorted(snr_metrics.items()):\n", " snr_accuracy = (metrics[\"correct_pixels\"] / metrics[\"total_pixels\"]) * 100 if metrics[\"total_pixels\"] > 0 else 0\n", " snr_iou = metrics[\"iou_sum\"] / metrics[\"iou_count\"] if metrics[\"iou_count\"] > 0 else 0\n", " print(f\"SNR: {snr} dB - Accuracy: {snr_accuracy:.2f}%\")\n", " print(f\" IoU: {snr_iou:.4f}\")\n", " for th in iou_thresholds:\n", " recall = metrics[\"recall_counts\"][th] / metrics[\"total_samples\"][th] if metrics[\"total_samples\"][th] > 0 else 0\n", " print(f\" Recall at threshold {th}: {recall:.4f}\")\n", "\n", " return snr_metrics\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a71c18ba", "metadata": { "scrolled": false }, "outputs": [], "source": [ "snr_metrics = evaluate(model_predictor, full_loader, device=device)" ] }, { "cell_type": "markdown", "id": "87417c7b", "metadata": {}, "source": [ "### Plot and Save" ] }, { "cell_type": "code", "execution_count": null, "id": "1dbfb5e6", "metadata": { "scrolled": false }, "outputs": [], "source": [ "import json\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", "\n", "# Define the path for saving the JSON file and plots\n", "save_path = Path(\"CMuSeNet_plots/Synthetic\")\n", "save_path.mkdir(parents=True, exist_ok=True)\n", "json_file_path = save_path / \"evaluation_results.json\"\n", "\n", "# Save metrics and plot results\n", "def save_and_plot_results(snr_metrics, iou_thresholds):\n", " # Prepare data for plotting and JSON saving\n", " snr_values = sorted(snr_metrics.keys())\n", " iou_scores = [snr_metrics[snr][\"iou_sum\"] / snr_metrics[snr][\"iou_count\"] if snr_metrics[snr][\"iou_count\"] > 0 else 0 for snr in snr_values]\n", " accuracies = [(snr_metrics[snr][\"correct_pixels\"] / snr_metrics[snr][\"total_pixels\"]) * 100 if snr_metrics[snr][\"total_pixels\"] > 0 else 0 for snr in snr_values]\n", " recalls = {th: [(snr_metrics[snr][\"recall_counts\"][th] / snr_metrics[snr][\"total_samples\"][th]) if snr_metrics[snr][\"total_samples\"][th] > 0 else 0 for snr in snr_values] for th in iou_thresholds}\n", "\n", " # Save results to JSON\n", " results = {\n", " \"SNR\": snr_values,\n", " \"IoU_Scores\": iou_scores,\n", " \"Accuracy\": accuracies,\n", " \"Recall\": {str(th): recalls[th] for th in iou_thresholds}\n", " }\n", " with open(json_file_path, \"w\") as f:\n", " json.dump(results, f, indent=4)\n", " print(f\"Results saved to {json_file_path}\")\n", "\n", " # Plot IoU vs SNR\n", " plt.figure()\n", " plt.plot(snr_values, iou_scores, marker='o', label=\"IoU Score\")\n", " plt.xlabel(\"SNR (dB)\")\n", " plt.ylabel(\"IoU Score\")\n", " plt.title(\"IoU Score vs. SNR\")\n", " plt.grid(True)\n", " plt.legend()\n", " plt.savefig(save_path / \"IoU_vs_SNR.png\")\n", " plt.savefig(save_path / \"IoU_vs_SNR.svg\")\n", " plt.show()\n", "\n", " # Plot Accuracy vs SNR\n", " plt.figure()\n", " plt.plot(snr_values, accuracies, marker='o', label=\"Accuracy\")\n", " plt.xlabel(\"SNR (dB)\")\n", " plt.ylabel(\"Accuracy (%)\")\n", " plt.title(\"Accuracy vs. SNR (Threshold 0.5)\")\n", " plt.grid(True)\n", " plt.legend()\n", " plt.savefig(save_path / \"Accuracy_vs_SNR.png\")\n", " plt.savefig(save_path / \"Accuracy_vs_SNR.svg\")\n", " plt.show()\n", "\n", " # Plot Recall vs SNR for each threshold\n", " for th in iou_thresholds:\n", " plt.figure()\n", " plt.plot(snr_values, recalls[th], marker='o', label=f\"Recall at {th}\")\n", " plt.xlabel(\"SNR (dB)\")\n", " plt.ylabel(\"Recall\")\n", " plt.title(f\"Recall vs. SNR (Threshold {th})\")\n", " plt.grid(True)\n", " plt.legend()\n", " plt.savefig(save_path / f\"Recall_vs_SNR_{th}.png\")\n", " plt.savefig(save_path / f\"Recall_vs_SNR_{th}.svg\")\n", " plt.show()\n", "\n", "# Call this after running evaluate() to save and plot results\n", "save_and_plot_results(snr_metrics, iou_thresholds)" ] }, { "cell_type": "code", "execution_count": null, "id": "d0c0d3e8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }