File size: 11,723 Bytes
cbe30e6 facb3d5 cbe30e6 0c90d5f cbe30e6 4e938bd 687eaba cbe30e6 0c90d5f facb3d5 0c90d5f facb3d5 687eaba facb3d5 687eaba facb3d5 0c90d5f facb3d5 0c90d5f cbe30e6 facb3d5 687eaba cbe30e6 facb3d5 0c90d5f facb3d5 687eaba 0c90d5f 687eaba 0c90d5f 687eaba 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f 687eaba 0c90d5f facb3d5 687eaba 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 cbe30e6 facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 687eaba facb3d5 0c90d5f facb3d5 687eaba facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 0c90d5f facb3d5 687eaba facb3d5 0c90d5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
import torch
from torch import nn
import logging
from typing import Tuple, List, Optional
from src.config.schemas import SystemConfig, ModelConfig
from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, ChannelAdapter
class BaseFortiTranEstimator(nn.Module):
    """
    Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.

    The complex pilot input is split into its real and imaginary parts, and
    each part is passed through the *same* real-valued pipeline (shared
    weights). The pipeline performs channel estimation by:

    1. Upsampling pilot symbols to full OFDM grid size (with linear layer)
    2. Applying convolutional enhancement for subcarrier-symbol features
    3. Converting to patch embeddings for transformer processing
    4. Optionally concatenating encoded channel conditions onto the patch
       embeddings (only when ``use_channel_adaptation`` is True)
    5. Using transformer encoder to capture long-range dependencies
    6. Reconstructing subcarrier-symbol representation and applying residual connections
    7. Final convolutional refinement for high-quality channel estimates
    """

    def __init__(self, system_config: SystemConfig, model_config: ModelConfig,
                 use_channel_adaptation: bool = False) -> None:
        """
        Initialize the BaseFortiTranEstimator.

        Args:
            system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
            model_config: Model architecture configuration (patch size, layers, etc.)
            use_channel_adaptation: Whether to enable channel adaptation features (disabled for FortiTran)

        Raises:
            ValueError: If channel adaptation is enabled but
                ``model_config.adaptive_token_length`` or
                ``model_config.channel_adaptivity_hidden_sizes`` is missing,
                or the latter does not contain exactly 3 values.
        """
        super().__init__()
        self.system_config = system_config
        self.model_config = model_config
        self.use_channel_adaptation = use_channel_adaptation
        # Device requested by the configuration. This is a plain attribute, so
        # it is NOT updated if the module is moved later; forward() therefore
        # resolves the live device from the parameters instead.
        self.device = torch.device(model_config.device)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Cache key dimensions for efficiency
        self._setup_dimensions()

        # Initialize model components
        self._build_architecture()

        # Move model to specified device
        self.to(self.device)

        self._log_initialization_info()

    def _setup_dimensions(self) -> None:
        """
        Calculate and cache key dimensions from configuration.

        Sets ``ofdm_size``, ``pilot_size``, ``ofdm_features``,
        ``pilot_features``, ``patch_length`` and ``transformer_input_dim``.

        Raises:
            ValueError: If channel adaptation is enabled and
                ``model_config.adaptive_token_length`` is None.
        """
        # OFDM grid dimensions: (subcarriers, symbols)
        self.ofdm_size = (
            self.system_config.ofdm.num_scs,
            self.system_config.ofdm.num_symbols,
        )

        # Pilot arrangement dimensions: (subcarriers, symbols)
        self.pilot_size = (
            self.system_config.pilot.num_scs,
            self.system_config.pilot.num_symbols,
        )

        # Feature dimensions for linear layers (flattened grid sizes)
        self.pilot_features = self.pilot_size[0] * self.pilot_size[1]
        self.ofdm_features = self.ofdm_size[0] * self.ofdm_size[1]

        # Patch processing dimensions (elements per patch)
        self.patch_length = (
            self.model_config.patch_size[0] * self.model_config.patch_size[1]
        )

        # Transformer input dimension (includes channel tokens if adaptation is enabled)
        if self.use_channel_adaptation:
            if self.model_config.adaptive_token_length is None:
                raise ValueError("adaptive_token_length must be set when channel adaptation is enabled")
            self.transformer_input_dim = self.patch_length + self.model_config.adaptive_token_length
        else:
            self.transformer_input_dim = self.patch_length

    def _build_architecture(self) -> None:
        """
        Construct the model architecture components.

        Requires ``_setup_dimensions`` to have run first, since the layer
        sizes are taken from the cached dimensions.

        Raises:
            ValueError: If channel adaptation is enabled and
                ``model_config.channel_adaptivity_hidden_sizes`` is None or
                does not contain exactly 3 values.
        """
        # 1. Pilot-to-OFDM upsampling
        self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)

        # 2. Initial convolutional enhancement
        self.initial_enhancer = ConvEnhancer()

        # 3. Patch embedding for transformer processing
        self.patch_embedder = PatchEmbedding(self.model_config.patch_size)

        # 4. Channel adapter (conditional on use_channel_adaptation)
        if self.use_channel_adaptation:
            if self.model_config.channel_adaptivity_hidden_sizes is None:
                raise ValueError("channel_adaptivity_hidden_sizes must be set when channel adaptation is enabled")
            # Convert list to tuple as expected by ChannelAdapter (exactly 3 values)
            hidden_sizes = tuple(self.model_config.channel_adaptivity_hidden_sizes)
            if len(hidden_sizes) != 3:
                raise ValueError("channel_adaptivity_hidden_sizes must have exactly 3 values")
            self.channel_adapter = ChannelAdapter(hidden_sizes)

        # 5. Transformer encoder for sequence modeling
        transformer_output_dim = self.patch_length  # Always output standard patch length
        self.transformer_encoder = TransformerEncoderForChannels(
            input_dim=self.transformer_input_dim,
            output_dim=transformer_output_dim,
            model_dim=self.model_config.model_dim,
            num_head=self.model_config.num_head,
            activation=self.model_config.activation,
            dropout=self.model_config.dropout,
            num_layers=self.model_config.num_layers,
            max_len=self.model_config.max_seq_len,
            pos_encoding_type=self.model_config.pos_encoding_type,
        )

        # 6. Patch reconstruction
        self.patch_reconstructor = InversePatchEmbedding(
            self.ofdm_size,
            self.model_config.patch_size,
        )

        # 7. Final convolutional refinement
        self.final_refiner = ConvEnhancer()

    def _count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, trainable)`` parameter counts in a single pass."""
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return total, trainable

    def _runtime_device(self) -> torch.device:
        """
        Return the device the model weights currently live on.

        ``self.device`` only records the device requested at construction
        time. Moving the module afterwards (``.to()``, ``.cuda()``, ``.cpu()``)
        relocates the parameters but leaves that attribute unchanged, so the
        first parameter's device is used as the source of truth. Falls back
        to ``self.device`` if the module has no parameters.
        """
        first_param = next(self.parameters(), None)
        if first_param is None:
            return self.device
        return first_param.device

    def _log_initialization_info(self) -> None:
        """Log model initialization details."""
        adaptation_status = "enabled" if self.use_channel_adaptation else "disabled"
        total_params, trainable_params = self._count_parameters()

        self.logger.info(f"{self.__class__.__name__} initialized successfully:")
        self.logger.info(f" Channel adaptation: {adaptation_status}")
        self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
        self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
        self.logger.info(f" Patch size: {self.model_config.patch_size}")
        self.logger.info(f" Model dimension: {self.model_config.model_dim}")
        self.logger.info(f" Transformer layers: {self.model_config.num_layers}")
        self.logger.info(f" Device: {self.device}")
        self.logger.info(f" Total parameters: {total_params:,}")
        self.logger.info(f" Trainable parameters: {trainable_params:,}")

    def forward(self, pilot_symbols: torch.Tensor, meta_data: Optional[Tuple] = None) -> torch.Tensor:
        """
        Forward pass for channel estimation.

        Args:
            pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
            meta_data: Channel conditions (only used if channel adaptation is enabled).
                When used it must unpack into exactly 6 items; the items at
                positions 1, 2 and 3 are read as SNR, delay spread and maximum
                Doppler shift tensors, and the remaining items are ignored.

        Returns:
            Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]

        Raises:
            ValueError: If channel adaptation is enabled and ``meta_data`` is
                None (or does not unpack into 6 items).
        """
        # Validate inputs based on adaptation mode
        if self.use_channel_adaptation and meta_data is None:
            raise ValueError("meta_data is required when channel adaptation is enabled")
        if not self.use_channel_adaptation and meta_data is not None:
            self.logger.warning("meta_data provided but channel adaptation is disabled - ignoring meta_data")

        # Use the device the weights are actually on, so inputs still land in
        # the right place if the model was moved after construction.
        device = self._runtime_device()

        # Extract channel conditions if adaptation is enabled
        channel_conditions = None
        if self.use_channel_adaptation and meta_data is not None:
            _, snr, delay_spread, max_dop_shift, _, _ = meta_data
            channel_conditions = [
                tensor.to(device)
                for tensor in (snr, delay_spread, max_dop_shift)
            ]

        # Ensure input is on correct device
        pilot_symbols = pilot_symbols.to(device)

        # Process real and imaginary parts separately (same weights for both)
        real_estimate = self._forward_real_valued(pilot_symbols.real, channel_conditions)
        imag_estimate = self._forward_real_valued(pilot_symbols.imag, channel_conditions)

        # Combine into complex tensor
        return torch.complex(real_estimate, imag_estimate)

    def _forward_real_valued(self, x: torch.Tensor,
                             channel_conditions: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Process real-valued input through the estimation pipeline.

        Args:
            x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
            channel_conditions: Channel conditions for adaptation (optional);
                unpacked positionally into the channel adapter.

        Returns:
            Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
        """
        batch_size = x.shape[0]

        # Flatten subcarrier and symbol dimensions for linear upsampling.
        # reshape (not view) is used because the input may be non-contiguous,
        # e.g. pilots taken by strided slicing/transposition or the
        # .real/.imag views of such tensors; view would raise on those.
        if x.dim() > 2:
            x = x.reshape(batch_size, -1)

        # Stage 1: Upsample from pilot grid to OFDM grid
        upsampled = self.pilot_upsampler(x)

        # Reshape for convolutional processing: add a singleton channel axis
        upsampled_2d = upsampled.reshape(batch_size, 1, *self.ofdm_size)

        # Stage 2: Initial convolutional enhancement (channel axis dropped again)
        conv_enhanced = torch.squeeze(self.initial_enhancer(upsampled_2d), dim=1)

        # Stage 3: Convert to patch embeddings
        patch_embeddings = self.patch_embedder(conv_enhanced)

        # Stage 4: Apply channel adaptation if enabled
        if self.use_channel_adaptation and channel_conditions is not None:
            encoded_channel_condition = self.channel_adapter(*channel_conditions)
            # Concatenate along dim 2 so the width matches transformer_input_dim
            # (patch_length + adaptive_token_length).
            transformer_input = torch.cat((patch_embeddings, encoded_channel_condition), dim=2)
        else:
            transformer_input = patch_embeddings

        # Stage 5: Transformer processing for long-range dependencies
        transformer_output = self.transformer_encoder(transformer_input)

        # Stage 6: Reconstruct subcarrier-symbol representation
        reconstructed = self.patch_reconstructor(transformer_output)

        # Stage 7: Apply residual connection
        residual_combined = conv_enhanced + reconstructed

        # Stage 8: Final convolutional refinement
        refined_output = torch.squeeze(
            self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1
        )
        return refined_output

    def get_model_info(self) -> dict:
        """
        Return model configuration and statistics.

        Note that ``'device'`` reports the device configured at construction
        time (``self.device``).
        """
        total_params, trainable_params = self._count_parameters()
        return {
            'model_name': self.__class__.__name__,
            'channel_adaptation': self.use_channel_adaptation,
            'ofdm_size': self.ofdm_size,
            'pilot_size': self.pilot_size,
            'patch_size': self.model_config.patch_size,
            'patch_length': self.patch_length,
            'transformer_input_dim': self.transformer_input_dim,
            'model_dim': self.model_config.model_dim,
            'num_layers': self.model_config.num_layers,
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
        }
class FortiTranEstimator(BaseFortiTranEstimator):
    """
    Standard FortiTran channel estimator for OFDM systems.

    A thin specialisation of ``BaseFortiTranEstimator`` that always builds
    the base model with channel adaptation switched off.
    """

    def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
        """
        Build the estimator with channel adaptation disabled.

        Args:
            system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
            model_config: Model architecture configuration (patch size, layers, etc.)
        """
        super().__init__(
            system_config=system_config,
            model_config=model_config,
            use_channel_adaptation=False,
        )
|