from typing import Tuple

import torch
from torch import nn


class ChannelAdapter(nn.Module):
    """Nonlinear encoder for channel condition tokens.

    Creates embeddings for SNR, delay spread, and Doppler shift parameters.
    Each embedding is conditioned on a single real value and is produced by an MLP.
    """

    def __init__(self, hidden_sizes: Tuple[int, int, int]):
        """Initialize the token encoder modules.

        Args:
            hidden_sizes: Tuple of hidden layer dimensions (h1, h2, h3) for the
                MLP encoders; h3 must be even so each encoder output can be
                reshaped into width-2 tokens.
        """
        super().__init__()
        self.snr_encoder = self._create_mlp(hidden_sizes)
        self.ds_encoder = self._create_mlp(hidden_sizes)
        self.dop_encoder = self._create_mlp(hidden_sizes)

    @staticmethod
    def _create_mlp(hidden_sizes: Tuple[int, int, int]) -> nn.Sequential:
        """Create a multi-layer perceptron with the specified dimensions.

        Args:
            hidden_sizes: Tuple of hidden layer dimensions (h1, h2, h3).

        Returns:
            Sequential MLP with ReLU activations between the linear layers.
        """
        return nn.Sequential(
            nn.Linear(1, hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[1], hidden_sizes[2]),
        )

    def forward(
        self,
        snr: torch.Tensor,
        delay_spread: torch.Tensor,
        doppler_shift: torch.Tensor,
    ) -> torch.Tensor:
        """Create token embeddings from channel conditions.

        Args:
            snr: Signal-to-noise ratio tensor of shape (batch_size, 1).
            delay_spread: Delay spread tensor of shape (batch_size, 1).
            doppler_shift: Doppler shift tensor of shape (batch_size, 1).

        Returns:
            Concatenated token embeddings of shape
            (batch_size, hidden_sizes[2] // 2, 6), e.g. (batch_size, 3, 6)
            when hidden_sizes[2] == 6.
        """
        batch_size = snr.shape[0]
        # Each encoder maps its scalar input to hidden_sizes[2] features,
        # reshaped into (hidden_sizes[2] // 2) tokens of width 2.
        snr_emb = torch.reshape(self.snr_encoder(snr), (batch_size, -1, 2))
        ds_emb = torch.reshape(self.ds_encoder(delay_spread), (batch_size, -1, 2))
        dop_emb = torch.reshape(self.dop_encoder(doppler_shift), (batch_size, -1, 2))
        # Concatenate along the last axis so every token carries features from
        # all three channel conditions.
        return torch.cat((snr_emb, ds_emb, dop_emb), dim=2)
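

# Minimal usage sketch (the hidden sizes below are illustrative assumptions,
# not values from the original module). With hidden_sizes = (32, 16, 6), each
# encoder emits 6 features, giving 3 tokens of width 6 per sample.
if __name__ == "__main__":
    adapter = ChannelAdapter(hidden_sizes=(32, 16, 6))
    batch = 4
    snr = torch.randn(batch, 1)
    delay_spread = torch.randn(batch, 1)
    doppler_shift = torch.randn(batch, 1)
    tokens = adapter(snr, delay_spread, doppler_shift)
    print(tokens.shape)  # torch.Size([4, 3, 6])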