File size: 2,247 Bytes
cbe30e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from torch import nn
from typing import Tuple
import torch


class ChannelAdapter(nn.Module):
    """Nonlinear encoder for channel condition tokens.

    Creates embeddings for SNR, delay spread, and Doppler shift parameters.
    Each embedding is conditioned on a single real value and is the output of an MLP
    """

    def __init__(self, hidden_sizes: Tuple[int, int, int]):
        """Initialize the token encoder modules.

        Args:
            hidden_sizes: Tuple of hidden layer dimensions (h1, h2, h3) for the MLP encoders
        """
        super().__init__()
        self.snr_encoder = self._create_mlp(hidden_sizes)
        self.ds_encoder = self._create_mlp(hidden_sizes)
        self.dop_encoder = self._create_mlp(hidden_sizes)

    @staticmethod
    def _create_mlp(hidden_sizes: Tuple[int, int, int]) -> nn.Sequential:
        """Create a multi-layer perceptron with specified dimensions.

        Args:
            hidden_sizes: Tuple of hidden layer dimensions (h1, h2, h3)

        Returns:
            Sequential MLP model with ReLU activations between linear layers
        """
        return nn.Sequential(
            nn.Linear(1, hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[1], hidden_sizes[2])
        )

    def forward(
            self,
            snr: torch.Tensor,
            delay_spread: torch.Tensor,
            doppler_shift: torch.Tensor
    ) -> torch.Tensor:
        """Create token embeddings from channel conditions.

        Args:
            snr: Signal-to-Noise Ratio tensor of shape (batch_size, 1)
            delay_spread: Delay spread tensor of shape (batch_size, 1)
            doppler_shift: Doppler shift tensor of shape (batch_size, 1)

        Returns:
            Concatenated token embeddings of shape (batch_size, 3, 6)
        """
        batch_size = snr.shape[0]
        snr_emb = torch.reshape(self.snr_encoder(snr), (batch_size, -1, 2))
        ds_emb = torch.reshape(self.ds_encoder(delay_spread), (batch_size, -1, 2))
        dop_emb = torch.reshape(self.dop_encoder(doppler_shift), (batch_size, -1, 2))
        return torch.cat((snr_emb, ds_emb, dop_emb), dim=2)