import torch
from torch import nn

from .positional_encodings import LearnablePositionalEncoding, SinusoidalPositionalEncoding


class TransformerEncoderForChannels(nn.Module):
    """Transformer encoder for channels."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        model_dim: int = 128,
        num_head: int = 4,
        activation: str = "gelu",
        dropout: float = 0.1,
        num_layers: int = 3,
        max_len: int = 512,
        pos_encoding_type: str = "learnable",
    ):
        """Initialize the encoder.

        Args:
            input_dim: Input dimension
            output_dim: Output dimension
            model_dim: Model dimension
            num_head: Number of attention heads
            activation: Activation function name
            dropout: Dropout rate
            num_layers: Number of transformer layers
            max_len: Maximum sequence length
            pos_encoding_type: Type of positional encoding
                ("learnable" or "sinusoidal")
        """
        super().__init__()
        # Project inputs into the transformer's model dimension.
        self.linear_1 = nn.Linear(input_dim, model_dim)
        if pos_encoding_type == "learnable":
            self.positional_encoding = LearnablePositionalEncoding(max_len, model_dim)
        elif pos_encoding_type == "sinusoidal":
            self.positional_encoding = SinusoidalPositionalEncoding(max_len, model_dim)
        else:
            raise ValueError("pos_encoding_type must be 'learnable' or 'sinusoidal'")
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_head,
            dim_feedforward=2 * model_dim,
            activation=activation,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        # Project transformer outputs to the requested output dimension.
        self.linear_2 = nn.Linear(model_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process input through the encoder.

        Args:
            x: Input tensor (batch_size, seq_length, input_dim)

        Returns:
            Processed tensor (batch_size, seq_length, output_dim)
        """
        x = self.linear_1(x)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        return self.linear_2(x)
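
# Usage sketch (an assumption, not part of the original module): exercises the
# shape contract documented in forward(). The batch and dimension values below
# are illustrative. Run as part of the package (e.g. `python -m pkg.module`)
# so the relative import of .positional_encodings resolves.
if __name__ == "__main__":
    encoder = TransformerEncoderForChannels(input_dim=32, output_dim=16)
    # Random batch: 8 sequences of 100 tokens with 32 input features each.
    x = torch.randn(8, 100, 32)
    out = encoder(x)
    # Output keeps batch and sequence dims; features map to output_dim.
    assert out.shape == (8, 100, 16)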