import torch
from torch import nn

from .positional_encodings import LearnablePositionalEncoding, SinusoidalPositionalEncoding
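
# The positional encoding classes come from a sibling module; from their use
# below, both are assumed to be nn.Modules built as Encoding(max_len, model_dim)
# that add position information to a (batch_size, seq_length, model_dim) tensor
# and return a tensor of the same shape.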


class TransformerEncoderForChannels(nn.Module):
    """Transformer encoder for channel sequences.

    Projects the input features to the model dimension, adds positional
    encodings, runs a stack of self-attention layers, and projects the
    result to the output dimension.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        model_dim: int = 128,
        num_head: int = 4,
        activation: str = "gelu",
        dropout: float = 0.1,
        num_layers: int = 3,
        max_len: int = 512,
        pos_encoding_type: str = "learnable"
    ):
        """Initialize the encoder.

        Args:
            input_dim: Dimensionality of each input feature vector.
            output_dim: Dimensionality of each output feature vector.
            model_dim: Transformer hidden size (d_model); must be divisible by num_head.
            num_head: Number of self-attention heads.
            activation: Activation function used in the feed-forward sublayers.
            dropout: Dropout probability applied inside each encoder layer.
            num_layers: Number of stacked transformer encoder layers.
            max_len: Maximum supported sequence length for positional encodings.
            pos_encoding_type: Positional encoding type, "learnable" or "sinusoidal".
        """
        super().__init__()
        # Project raw input features up to the transformer's model dimension.
        self.linear_1 = nn.Linear(input_dim, model_dim)
        if pos_encoding_type == "learnable":
            self.positional_encoding = LearnablePositionalEncoding(max_len, model_dim)
        elif pos_encoding_type == "sinusoidal":
            self.positional_encoding = SinusoidalPositionalEncoding(max_len, model_dim)
        else:
            raise ValueError("pos_encoding_type must be 'learnable' or 'sinusoidal'")
        # batch_first=True keeps tensors in (batch, seq, feature) layout throughout.
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_head,
            dim_feedforward=2 * model_dim,
            activation=activation,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            transformer_layer,
            num_layers=num_layers
        )
        # Project from the model dimension down to the requested output dimension.
        self.linear_2 = nn.Linear(model_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process input through the encoder.

        Args:
            x: Input tensor of shape (batch_size, seq_length, input_dim).

        Returns:
            Output tensor of shape (batch_size, seq_length, output_dim).
        """
        x = self.linear_1(x)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        return self.linear_2(x)
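

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): runs a random batch through the encoder with its default settings.
# Because of the relative import above, execute it in package context, e.g.
# `python -m <package>.<module>`.
if __name__ == "__main__":
    model = TransformerEncoderForChannels(input_dim=2, output_dim=2)
    x = torch.randn(8, 64, 2)  # (batch_size, seq_length, input_dim)
    y = model(x)
    print(y.shape)  # expected: torch.Size([8, 64, 2])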