| """
|
| ResNet-style 1D CNN for radio modulation classification.
|
| Input: (batch, 2, seq_len) — I/Q channels.
|
| Output: (batch, num_classes) logits.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from typing import Optional
|
|
|
|
|
class ResidualBlock1d(nn.Module):
    """Two-convolution 1D residual block: conv-BN-ReLU-conv-BN, plus shortcut, then ReLU.

    The sequence length is preserved for any ``kernel_size`` (odd or even), so the
    main path can always be summed with the shortcut path.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Kernel width used by both convolutions.

    Shape:
        Input ``(batch, in_ch, length)`` -> output ``(batch, out_ch, length)``.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3):
        super().__init__()
        # padding="same" (PyTorch >= 1.9, stride 1) keeps the length unchanged.
        # For odd kernels it equals the symmetric kernel_size // 2 padding.
        # For even kernels it pads asymmetrically. kernel_size // 2 would instead
        # lengthen the output by one per conv and break the residual sum below.
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, padding="same")
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, padding="same")
        self.bn2 = nn.BatchNorm1d(out_ch)
        # A 1x1 projection matches the shortcut to out_ch.
        # Identity is used when the channel count is unchanged.
        self.downsample = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape ``(batch, in_ch, length)``."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.downsample(x)
        return torch.relu(out)
|
|
|
|
|
class ModulationClassifier(nn.Module):
    """ResNet-style 1D CNN mapping 2-channel (I/Q) sequences to class logits.

    The architecture has four stages:

    1. A stem: 7-tap conv, BatchNorm, ReLU, then 2x max-pool.
    2. ``num_blocks`` residual blocks. Block ``i`` has ``base_channels * 2**i``
       output channels.
    3. Global average pooling over time.
    4. A linear classification head.

    Args:
        num_classes: Number of output logits.
        seq_len: Nominal input length. It is stored on the module but not used
            by the layers built here; the global average pool accepts any length.
        base_channels: Width of the stem and of the first residual block.
        num_blocks: Number of residual blocks. ``0`` is allowed and gives a
            stem -> pool -> linear network.

    Raises:
        ValueError: If ``num_blocks`` is negative.
    """

    def __init__(
        self,
        num_classes: int = 10,
        seq_len: int = 128,
        base_channels: int = 32,
        num_blocks: int = 3,
    ):
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")
        super().__init__()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.stem = nn.Sequential(
            nn.Conv1d(2, base_channels, 7, padding=3),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )

        channels = [base_channels * (2 ** i) for i in range(num_blocks)]
        layers = []
        in_ch = base_channels
        for ch in channels:
            layers.append(ResidualBlock1d(in_ch, ch))
            in_ch = ch
        self.blocks = nn.Sequential(*layers)

        self.pool = nn.AdaptiveAvgPool1d(1)
        # After the loop, in_ch is the width of the last block.
        # It equals base_channels when num_blocks == 0.
        # channels[-1] would raise IndexError in that case.
        self.fc = nn.Linear(in_ch, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, 2, length)``."""
        x = self.stem(x)      # (batch, base_channels, length // 2)
        x = self.blocks(x)    # (batch, final_channels, length // 2)
        x = self.pool(x)      # (batch, final_channels, 1)
        x = x.flatten(1)      # (batch, final_channels)
        return self.fc(x)
|
|
|