BerkIGuler commited on
Commit
0c90d5f
·
1 Parent(s): b57f75f

combined adafortitran and fortitran with basefortitran to minimize code repetition. Added dataset.py

Browse files
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
2
  pydantic
3
- yaml
 
 
1
  torch
2
  pydantic
3
+ yaml
4
+ scipy
src/config/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from src.config.schemas import ModelConfig, SystemConfig
src/data/__init__.py ADDED
File without changes
src/data/dataset.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for loading and processing .mat files containing channel estimates for PyTorch."""
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Callable, List, Optional, Tuple, Union
5
+
6
+ import scipy.io as sio
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader
9
+
10
+ from src.utils import extract_values
11
+
12
+ __all__ = ['MatDataset', 'get_test_dataloaders']
13
+
14
+
15
+ @dataclass
16
+ class PilotDimensions:
17
+ """Container for pilot signal dimensions.
18
+
19
+ Stores and validates the dimensions of pilot signals used in channel estimation.
20
+
21
+ Attributes:
22
+ num_subcarriers: Number of subcarriers in the pilot signal
23
+ num_ofdm_symbols: Number of OFDM symbols in the pilot signal
24
+ """
25
+ num_subcarriers: int
26
+ num_ofdm_symbols: int
27
+
28
+ def __post_init__(self):
29
+ """Validate dimensions after initialization.
30
+
31
+ Raises:
32
+ ValueError: If either dimension is not a positive integer
33
+ """
34
+ if self.num_subcarriers <= 0 or self.num_ofdm_symbols <= 0:
35
+ raise ValueError("Pilot dimensions must be positive integers")
36
+
37
+ def as_tuple(self) -> Tuple[int, int]:
38
+ """Return dimensions as a tuple.
39
+
40
+ Returns:
41
+ Tuple of (num_subcarriers, num_ofdm_symbols)
42
+ """
43
+ return self.num_subcarriers, self.num_ofdm_symbols
44
+
45
+
46
+ class MatDataset(Dataset):
47
+ """Dataset for loading and formatting .mat files containing channel estimates.
48
+
49
+ Processes .mat files containing channel estimation data and converts them into
50
+ PyTorch complex tensors for channel estimation tasks.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ data_dir: Union[str, Path],
56
+ pilot_dims: List[int],
57
+ transform: Optional[Callable] = None
58
+ ) -> None:
59
+ """Initialize the MatDataset.
60
+
61
+ Args:
62
+ data_dir: Path to the directory containing the dataset.
63
+ pilot_dims: Dimensions of pilot data as [num_subcarriers, num_ofdm_symbols].
64
+ transform: Optional transformation to apply to samples.
65
+
66
+ Raises:
67
+ ValueError: If pilot dimensions are invalid.
68
+ FileNotFoundError: If data_dir doesn't exist.
69
+ """
70
+ self.data_dir = Path(data_dir)
71
+ self.pilot_dims = PilotDimensions(*pilot_dims)
72
+ self.transform = transform
73
+
74
+ if not self.data_dir.exists():
75
+ raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
76
+
77
+ self.file_list = list(self.data_dir.glob("*.mat"))
78
+ if not self.file_list:
79
+ raise ValueError(f"No .mat files found in {self.data_dir}")
80
+
81
+ def __len__(self) -> int:
82
+ """Return the total number of files in the dataset.
83
+
84
+ Returns:
85
+ Integer count of .mat files in the dataset directory
86
+ """
87
+ return len(self.file_list)
88
+
89
+ def _process_channel_data(
90
+ self,
91
+ h_ideal: torch.Tensor,
92
+ mat_data: dict
93
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
94
+ """Process channel data and extract pilot values from LS estimates.
95
+
96
+ Extracts pilot values from LS channel estimates with zero entries removed,
97
+ returning complex-valued tensors for both estimate and ground truth.
98
+
99
+ Args:
100
+ h_ideal: Ground truth channel tensor
101
+ mat_data: Loaded .mat file data
102
+
103
+ Returns:
104
+ Tuple of (pilot LS estimate, ground truth channel)
105
+
106
+ Raises:
107
+ ValueError: If the data format is unexpected or processing fails
108
+ """
109
+ try:
110
+ # Extract LS channel estimate with zero entries
111
+ hzero_ls = torch.tensor(mat_data['H'][:, :, 1], dtype=torch.cfloat)
112
+
113
+ # Remove zero entries, keep only pilot values
114
+ zero_complex = torch.complex(torch.tensor(0.0), torch.tensor(0.0))
115
+ hp_ls = hzero_ls[hzero_ls != zero_complex]
116
+
117
+ # Validate expected number of pilot values
118
+ expected_pilots = self.pilot_dims.num_subcarriers * self.pilot_dims.num_ofdm_symbols
119
+ if hp_ls.numel() != expected_pilots:
120
+ raise ValueError(
121
+ f"Expected {expected_pilots} pilot values, got {hp_ls.numel()}"
122
+ )
123
+
124
+ # Reshape to pilot grid dimensions [subcarriers, symbols]
125
+ hp_ls = hp_ls.unsqueeze(dim=1).view(
126
+ self.pilot_dims.num_ofdm_symbols,
127
+ self.pilot_dims.num_subcarriers
128
+ ).t()
129
+
130
+ return hp_ls, h_ideal
131
+
132
+ except Exception as e:
133
+ raise ValueError(f"Error processing channel data: {str(e)}")
134
+
135
+ def __getitem__(
136
+ self,
137
+ idx: int
138
+ ) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
139
+ """Load and process a .mat file at the given index.
140
+
141
+ Args:
142
+ idx: Index of the file to load.
143
+
144
+ Returns:
145
+ Tuple containing:
146
+ - Pilot LS channel estimate (complex tensor)
147
+ - Ground truth channel estimate (complex tensor)
148
+ - Metadata extracted from filename
149
+
150
+ Raises:
151
+ ValueError: If file format is invalid or processing fails.
152
+ IndexError: If idx is out of range.
153
+ """
154
+ if not 0 <= idx < len(self):
155
+ raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")
156
+
157
+ try:
158
+ # Load .mat file
159
+ mat_data = sio.loadmat(self.file_list[idx])
160
+ if 'H' not in mat_data or mat_data['H'].shape[-1] < 3:
161
+ raise ValueError("Invalid .mat file format: missing required data")
162
+
163
+ # Extract ground truth channel
164
+ h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)
165
+
166
+ # Process channel data to extract pilot estimates
167
+ h_est, h_ideal = self._process_channel_data(h_ideal, mat_data)
168
+
169
+ # Extract metadata from filename
170
+ meta_data = extract_values(self.file_list[idx].name)
171
+ if meta_data is None:
172
+ raise ValueError(f"Unrecognized filename format: {self.file_list[idx].name}")
173
+
174
+ # Apply optional transforms
175
+ if self.transform:
176
+ h_est = self.transform(h_est)
177
+ h_ideal = self.transform(h_ideal)
178
+
179
+ return h_est, h_ideal, meta_data
180
+
181
+ except Exception as e:
182
+ raise ValueError(f"Error processing file {self.file_list[idx]}: {str(e)}")
183
+
184
+
185
+ def get_test_dataloaders(
186
+ dataset_dir: Union[str, Path],
187
+ params: dict
188
+ ) -> List[Tuple[str, DataLoader]]:
189
+ """Create DataLoaders for each subdirectory in the dataset directory.
190
+
191
+ Automatically discovers and creates appropriate DataLoader instances for
192
+ all subdirectories in the specified dataset directory, useful for testing
193
+ across multiple test conditions or scenarios.
194
+
195
+ Args:
196
+ dataset_dir: Path to main directory containing dataset subdirectories
197
+ params: Configuration parameters including:
198
+ - pilot_dims: List of [num_subcarriers, num_ofdm_symbols]
199
+ - batch_size: Number of samples per batch
200
+
201
+ Returns:
202
+ List of tuples containing (subdirectory_name, corresponding_dataloader)
203
+
204
+ Raises:
205
+ FileNotFoundError: If dataset_dir doesn't exist
206
+ ValueError: If params are invalid or no valid subdirectories are found
207
+ """
208
+ dataset_dir = Path(dataset_dir)
209
+ if not dataset_dir.exists():
210
+ raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
211
+
212
+ if not isinstance(params, dict) or "pilot_dims" not in params or "batch_size" not in params:
213
+ raise ValueError("params must be a dict containing 'pilot_dims' and 'batch_size'")
214
+
215
+ subdirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
216
+ if not subdirs:
217
+ raise ValueError(f"No subdirectories found in {dataset_dir}")
218
+
219
+ test_datasets = [
220
+ (
221
+ subdir.name,
222
+ MatDataset(
223
+ subdir,
224
+ params["pilot_dims"]
225
+ )
226
+ )
227
+ for subdir in subdirs
228
+ ]
229
+
230
+ return [
231
+ (name, DataLoader(
232
+ dataset,
233
+ batch_size=params["batch_size"],
234
+ shuffle=False,
235
+ num_workers=0
236
+ ))
237
+ for name, dataset in test_datasets
238
+ ]
src/models/adafortitran.py CHANGED
@@ -1,25 +1,14 @@
1
- import torch
2
- from torch import nn
3
- import logging
4
- from typing import Tuple, List
5
 
6
- from src.config.schemas import SystemConfig, ModelConfig
7
- from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, ChannelAdapter
8
-
9
-
10
- class AdaFortiTranEstimator(nn.Module):
11
 
 
12
  """
13
- Hybrid CNN-Transformer Channel Estimator for OFDM Systems with channel adaptation.
14
 
15
- This model performs channel estimation by:
16
- 1. Upsampling pilot symbols to full OFDM grid size
17
- 2. Applying convolutional enhancement for spatial features
18
- 3. Converting to patch embeddings for transformer processing
19
- 4. Concatenating channel statistics priors to channel patches
20
- 5. Using transformer encoder to capture long-range dependencies
21
- 6. Reconstructing spatial representation and applying residual connections
22
- 7. Final convolutional refinement for high-quality channel estimates
23
  """
24
 
25
  def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
@@ -30,187 +19,4 @@ class AdaFortiTranEstimator(nn.Module):
30
  system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
31
  model_config: Model architecture configuration (patch size, layers, etc.)
32
  """
33
- super().__init__()
34
-
35
- self.system_config = system_config
36
- self.model_config = model_config
37
- self.device = torch.device(model_config.device)
38
- self.logger = logging.getLogger(self.__class__.__name__)
39
-
40
- # Cache key dimensions for efficiency
41
- self._setup_dimensions()
42
-
43
- # Initialize model components
44
- self._build_architecture()
45
-
46
- # Move model to specified device
47
- self.to(self.device)
48
-
49
- self._log_initialization_info()
50
-
51
- def _setup_dimensions(self) -> None:
52
- """Calculate and cache key dimensions from configuration."""
53
- # OFDM grid dimensions
54
- self.ofdm_size = (
55
- self.system_config.ofdm.num_scs,
56
- self.system_config.ofdm.num_symbols
57
- )
58
-
59
- # Pilot arrangement dimensions
60
- self.pilot_size = (
61
- self.system_config.pilot.num_scs,
62
- self.system_config.pilot.num_symbols
63
- )
64
-
65
- # Feature dimensions for linear layers
66
- self.pilot_features = self.pilot_size[0] * self.pilot_size[1]
67
- self.ofdm_features = self.ofdm_size[0] * self.ofdm_size[1]
68
-
69
- # Patch processing dimensions
70
- self.patch_length = (
71
- self.model_config.patch_size[0] * self.model_config.patch_size[1]
72
- )
73
-
74
- self.adaptive_patch_length = self.patch_length + self.model_config.adaptive_token_length
75
-
76
- def _build_architecture(self) -> None:
77
- """Construct the model architecture components."""
78
- # 1. Pilot-to-OFDM upsampling
79
- self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)
80
- # 2. Initial convolutional enhancement
81
- self.initial_enhancer = ConvEnhancer()
82
-
83
- # 3. Patch embedding for transformer processing
84
- self.patch_embedder = PatchEmbedding(self.model_config.patch_size)
85
-
86
- # 4. Channel adapter for conditional attention
87
- self.channel_adapter = ChannelAdapter(self.model_config.channel_adaptivity_hidden_sizes)
88
-
89
- # 5. Transformer encoder for sequence modeling
90
- self.transformer_encoder = TransformerEncoderForChannels(
91
- input_dim=self.adaptive_patch_length,
92
- output_dim=self.patch_length,
93
- model_dim=self.model_config.model_dim,
94
- num_head=self.model_config.num_head,
95
- activation=self.model_config.activation,
96
- dropout=self.model_config.dropout,
97
- num_layers=self.model_config.num_layers,
98
- max_len=self.model_config.max_seq_len,
99
- pos_encoding_type=self.model_config.pos_encoding_type
100
- )
101
-
102
- # 6. Patch reconstruction
103
- self.patch_reconstructor = InversePatchEmbedding(
104
- self.ofdm_size,
105
- self.model_config.patch_size
106
- )
107
-
108
- # 7. Final convolutional refinement
109
- self.final_refiner = ConvEnhancer()
110
-
111
- def _log_initialization_info(self) -> None:
112
- """Log model initialization details."""
113
- self.logger.info("AdaFortiTranEstimator initialized successfully:")
114
- self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
115
- self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
116
- self.logger.info(f" Patch size: {self.model_config.patch_size}")
117
- self.logger.info(f" Model dimension: {self.model_config.model_dim}")
118
- self.logger.info(f" Transformer layers: {self.model_config.num_layers}")
119
- self.logger.info(f" Device: {self.device}")
120
-
121
- total_params = sum(p.numel() for p in self.parameters())
122
- trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
123
- self.logger.info(f" Total parameters: {total_params:,}")
124
- self.logger.info(f" Trainable parameters: {trainable_params:,}")
125
-
126
- def forward(self, pilot_symbols: torch.Tensor, meta_data: Tuple) -> torch.Tensor:
127
- """
128
- Forward pass for channel estimation.
129
-
130
- Args:
131
- pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
132
- meta_data: TODO: Add complete type annotation.
133
-
134
- Returns:
135
- Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
136
- """
137
-
138
- # Extract and move channel conditions to device
139
- _, snr, delay_spread, max_dop_shift, _, _ = meta_data
140
- channel_conditions = [
141
- tensor.to(self.device)
142
- for tensor in (snr, delay_spread, max_dop_shift)
143
- ]
144
-
145
- # Ensure input is on correct device
146
- pilot_symbols = pilot_symbols.to(self.device)
147
-
148
- # Process real and imaginary parts separately
149
- real_estimate = self._forward_real_valued(pilot_symbols.real, channel_conditions)
150
- imag_estimate = self._forward_real_valued(pilot_symbols.imag, channel_conditions)
151
-
152
- # Combine into complex tensor
153
- channel_estimate = torch.complex(real_estimate, imag_estimate)
154
-
155
- return channel_estimate
156
-
157
- def _forward_real_valued(self, x: torch.Tensor, channel_conditions: List[torch.Tensor]) -> torch.Tensor:
158
- """
159
- Process real-valued input through the estimation pipeline.
160
-
161
- Args:
162
- x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
163
-
164
- Returns:
165
- Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
166
- """
167
- batch_size = x.shape[0]
168
-
169
- # Flatten spatial dimensions for linear upsampling
170
- if x.dim() > 2:
171
- x = x.view(batch_size, -1)
172
-
173
- # Stage 1: Upsample from pilot grid to OFDM grid
174
- upsampled = self.pilot_upsampler(x)
175
-
176
- # Reshape for convolutional processing
177
- upsampled_2d = upsampled.view(batch_size, 1, *self.ofdm_size)
178
-
179
- # Stage 2: Initial convolutional enhancement
180
- conv_enhanced = torch.squeeze(self.initial_enhancer(upsampled_2d), dim=1)
181
-
182
- # Stage 3: Convert to patch embeddings
183
- patch_embeddings = self.patch_embedder(conv_enhanced)
184
-
185
- # Stage 4: Get conditioned channel encodings
186
- encoded_channel_condition = self.channel_adapter(*channel_conditions)
187
- conditioned_channel_encodings = torch.cat((patch_embeddings, encoded_channel_condition), dim=2)
188
-
189
- # Stage 5: Transformer processing for long-range dependencies
190
- transformer_output = self.transformer_encoder(conditioned_channel_encodings)
191
-
192
- # Stage 6: Reconstruct spatial representation
193
- reconstructed = self.patch_reconstructor(transformer_output)
194
-
195
- # Stage 7: Apply residual connection
196
- residual_combined = conv_enhanced + reconstructed
197
-
198
- # Stage 8: Final convolutional refinement
199
- refined_output = torch.squeeze(self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1)
200
-
201
- return refined_output
202
-
203
- def get_model_info(self) -> dict:
204
- """Return model configuration and statistics."""
205
- return {
206
- 'model_name': self.__class__.__name__,
207
- 'ofdm_size': self.ofdm_size,
208
- 'pilot_size': self.pilot_size,
209
- 'patch_size': self.model_config.patch_size,
210
- 'patch_length': self.patch_length,
211
- 'model_dim': self.model_config.model_dim,
212
- 'num_layers': self.model_config.num_layers,
213
- 'device': str(self.device),
214
- 'total_parameters': sum(p.numel() for p in self.parameters()),
215
- 'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
216
- }
 
1
+ from .fortitran import BaseFortiTranEstimator
2
+ from src.config import SystemConfig, ModelConfig
 
 
3
 
 
 
 
 
 
4
 
5
+ class AdaFortiTranEstimator(BaseFortiTranEstimator):
6
  """
7
+ Adaptive Hybrid CNN-Transformer Channel Estimator for OFDM Systems with channel adaptation.
8
 
9
+ This model extends the base estimator with channel adaptation capabilities,
10
+ incorporating channel conditions (SNR, delay spread, Doppler shift) into
11
+ the estimation process through conditional attention mechanisms.
 
 
 
 
 
12
  """
13
 
14
  def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
 
19
  system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
20
  model_config: Model architecture configuration (patch size, layers, etc.)
21
  """
22
+ super().__init__(system_config, model_config, use_channel_adaptation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/fortitran.py CHANGED
@@ -1,14 +1,16 @@
1
  import torch
2
  from torch import nn
3
  import logging
 
4
 
5
- from src.config.schemas import SystemConfig, ModelConfig
6
- from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels
 
7
 
8
 
9
- class FortiTranEstimator(nn.Module):
10
  """
11
- Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
12
 
13
  This model performs channel estimation by:
14
  1. Upsampling pilot symbols to full OFDM grid size
@@ -19,18 +21,21 @@ class FortiTranEstimator(nn.Module):
19
  6. Final convolutional refinement for high-quality channel estimates
20
  """
21
 
22
- def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
 
23
  """
24
- Initialize the FortiTranEstimator.
25
 
26
  Args:
27
  system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
28
  model_config: Model architecture configuration (patch size, layers, etc.)
 
29
  """
30
  super().__init__()
31
 
32
  self.system_config = system_config
33
  self.model_config = model_config
 
34
  self.device = torch.device(model_config.device)
35
  self.logger = logging.getLogger(self.__class__.__name__)
36
 
@@ -68,41 +73,57 @@ class FortiTranEstimator(nn.Module):
68
  self.model_config.patch_size[0] * self.model_config.patch_size[1]
69
  )
70
 
 
 
 
 
 
 
71
  def _build_architecture(self) -> None:
72
  """Construct the model architecture components."""
73
  # 1. Pilot-to-OFDM upsampling
74
  self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)
 
75
  # 2. Initial convolutional enhancement
76
  self.initial_enhancer = ConvEnhancer()
77
 
78
  # 3. Patch embedding for transformer processing
79
  self.patch_embedder = PatchEmbedding(self.model_config.patch_size)
80
 
81
- # 4. Transformer encoder for sequence modeling
 
 
 
 
 
 
 
82
  self.transformer_encoder = TransformerEncoderForChannels(
83
- input_dim=self.patch_length,
84
- output_dim=self.patch_length,
85
  model_dim=self.model_config.model_dim,
86
  num_head=self.model_config.num_head,
87
  activation=self.model_config.activation,
88
  dropout=self.model_config.dropout,
89
  num_layers=self.model_config.num_layers,
90
  max_len=self.model_config.max_seq_len,
91
- pos_encoding_type=self.model_config.pos_encoding_type,
92
  )
93
 
94
- # 5. Patch reconstruction
95
  self.patch_reconstructor = InversePatchEmbedding(
96
  self.ofdm_size,
97
  self.model_config.patch_size
98
  )
99
 
100
- # 6. Final convolutional refinement
101
  self.final_refiner = ConvEnhancer()
102
 
103
  def _log_initialization_info(self) -> None:
104
  """Log model initialization details."""
105
- self.logger.info("FortiTranEstimator initialized successfully:")
 
 
106
  self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
107
  self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
108
  self.logger.info(f" Patch size: {self.model_config.patch_size}")
@@ -115,34 +136,53 @@ class FortiTranEstimator(nn.Module):
115
  self.logger.info(f" Total parameters: {total_params:,}")
116
  self.logger.info(f" Trainable parameters: {trainable_params:,}")
117
 
118
- def forward(self, pilot_symbols: torch.Tensor) -> torch.Tensor:
119
  """
120
  Forward pass for channel estimation.
121
 
122
  Args:
123
  pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
 
124
 
125
  Returns:
126
  Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
127
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  # Ensure input is on correct device
129
  pilot_symbols = pilot_symbols.to(self.device)
130
 
131
  # Process real and imaginary parts separately
132
- real_estimate = self._forward_real_valued(pilot_symbols.real)
133
- imag_estimate = self._forward_real_valued(pilot_symbols.imag)
134
 
135
  # Combine into complex tensor
136
  channel_estimate = torch.complex(real_estimate, imag_estimate)
137
 
138
  return channel_estimate
139
 
140
- def _forward_real_valued(self, x: torch.Tensor) -> torch.Tensor:
 
141
  """
142
  Process real-valued input through the estimation pipeline.
143
 
144
  Args:
145
  x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
 
146
 
147
  Returns:
148
  Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
@@ -165,16 +205,23 @@ class FortiTranEstimator(nn.Module):
165
  # Stage 3: Convert to patch embeddings
166
  patch_embeddings = self.patch_embedder(conv_enhanced)
167
 
168
- # Stage 4: Transformer processing for long-range dependencies
169
- transformer_output = self.transformer_encoder(patch_embeddings)
 
 
 
 
 
 
 
170
 
171
- # Stage 5: Reconstruct spatial representation
172
  reconstructed = self.patch_reconstructor(transformer_output)
173
 
174
- # Stage 6: Apply residual connection
175
  residual_combined = conv_enhanced + reconstructed
176
 
177
- # Stage 7: Final convolutional refinement
178
  refined_output = torch.squeeze(self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1)
179
 
180
  return refined_output
@@ -183,13 +230,33 @@ class FortiTranEstimator(nn.Module):
183
  """Return model configuration and statistics."""
184
  return {
185
  'model_name': self.__class__.__name__,
 
186
  'ofdm_size': self.ofdm_size,
187
  'pilot_size': self.pilot_size,
188
  'patch_size': self.model_config.patch_size,
189
  'patch_length': self.patch_length,
 
190
  'model_dim': self.model_config.model_dim,
191
  'num_layers': self.model_config.num_layers,
192
  'device': str(self.device),
193
  'total_parameters': sum(p.numel() for p in self.parameters()),
194
  'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
195
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from torch import nn
3
  import logging
4
+ from typing import Tuple, List, Optional
5
 
6
+ from src.config import SystemConfig, ModelConfig
7
+ from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, \
8
+ ChannelAdapter
9
 
10
 
11
+ class BaseFortiTranEstimator(nn.Module):
12
  """
13
+ Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
14
 
15
  This model performs channel estimation by:
16
  1. Upsampling pilot symbols to full OFDM grid size
 
21
  6. Final convolutional refinement for high-quality channel estimates
22
  """
23
 
24
+ def __init__(self, system_config: SystemConfig, model_config: ModelConfig,
25
+ use_channel_adaptation: bool = False) -> None:
26
  """
27
+ Initialize the BaseFortiTranEstimator.
28
 
29
  Args:
30
  system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
31
  model_config: Model architecture configuration (patch size, layers, etc.)
32
+ use_channel_adaptation: Whether to enable channel adaptation features
33
  """
34
  super().__init__()
35
 
36
  self.system_config = system_config
37
  self.model_config = model_config
38
+ self.use_channel_adaptation = use_channel_adaptation
39
  self.device = torch.device(model_config.device)
40
  self.logger = logging.getLogger(self.__class__.__name__)
41
 
 
73
  self.model_config.patch_size[0] * self.model_config.patch_size[1]
74
  )
75
 
76
+ # Adaptive patch length (only used if channel adaptation is enabled)
77
+ if self.use_channel_adaptation:
78
+ self.adaptive_patch_length = self.patch_length + self.model_config.adaptive_token_length
79
+ else:
80
+ self.adaptive_patch_length = self.patch_length
81
+
82
  def _build_architecture(self) -> None:
83
  """Construct the model architecture components."""
84
  # 1. Pilot-to-OFDM upsampling
85
  self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)
86
+
87
  # 2. Initial convolutional enhancement
88
  self.initial_enhancer = ConvEnhancer()
89
 
90
  # 3. Patch embedding for transformer processing
91
  self.patch_embedder = PatchEmbedding(self.model_config.patch_size)
92
 
93
+ # 4. Channel adapter (conditional on use_channel_adaptation)
94
+ if self.use_channel_adaptation:
95
+ self.channel_adapter = ChannelAdapter(self.model_config.channel_adaptivity_hidden_sizes)
96
+
97
+ # 5. Transformer encoder for sequence modeling
98
+ transformer_input_dim = self.adaptive_patch_length if self.use_channel_adaptation else self.patch_length
99
+ transformer_output_dim = self.patch_length # Always output standard patch length
100
+
101
  self.transformer_encoder = TransformerEncoderForChannels(
102
+ input_dim=transformer_input_dim,
103
+ output_dim=transformer_output_dim,
104
  model_dim=self.model_config.model_dim,
105
  num_head=self.model_config.num_head,
106
  activation=self.model_config.activation,
107
  dropout=self.model_config.dropout,
108
  num_layers=self.model_config.num_layers,
109
  max_len=self.model_config.max_seq_len,
110
+ pos_encoding_type=self.model_config.pos_encoding_type
111
  )
112
 
113
+ # 6. Patch reconstruction
114
  self.patch_reconstructor = InversePatchEmbedding(
115
  self.ofdm_size,
116
  self.model_config.patch_size
117
  )
118
 
119
+ # 7. Final convolutional refinement
120
  self.final_refiner = ConvEnhancer()
121
 
122
  def _log_initialization_info(self) -> None:
123
  """Log model initialization details."""
124
+ adaptation_status = "enabled" if self.use_channel_adaptation else "disabled"
125
+ self.logger.info(f"{self.__class__.__name__} initialized successfully:")
126
+ self.logger.info(f" Channel adaptation: {adaptation_status}")
127
  self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
128
  self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
129
  self.logger.info(f" Patch size: {self.model_config.patch_size}")
 
136
  self.logger.info(f" Total parameters: {total_params:,}")
137
  self.logger.info(f" Trainable parameters: {trainable_params:,}")
138
 
139
+ def forward(self, pilot_symbols: torch.Tensor, meta_data: Optional[Tuple] = None) -> torch.Tensor:
140
  """
141
  Forward pass for channel estimation.
142
 
143
  Args:
144
  pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
145
+ meta_data: Channel conditions (only used if channel adaptation is enabled)
146
 
147
  Returns:
148
  Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
149
  """
150
+ # Validate inputs based on adaptation mode
151
+ if self.use_channel_adaptation and meta_data is None:
152
+ raise ValueError("meta_data is required when channel adaptation is enabled")
153
+
154
+ if not self.use_channel_adaptation and meta_data is not None:
155
+ self.logger.warning("meta_data provided but channel adaptation is disabled - ignoring meta_data")
156
+
157
+ # Extract channel conditions if adaptation is enabled
158
+ channel_conditions = None
159
+ if self.use_channel_adaptation and meta_data is not None:
160
+ _, snr, delay_spread, max_dop_shift, _, _ = meta_data
161
+ channel_conditions = [
162
+ tensor.to(self.device)
163
+ for tensor in (snr, delay_spread, max_dop_shift)
164
+ ]
165
+
166
  # Ensure input is on correct device
167
  pilot_symbols = pilot_symbols.to(self.device)
168
 
169
  # Process real and imaginary parts separately
170
+ real_estimate = self._forward_real_valued(pilot_symbols.real, channel_conditions)
171
+ imag_estimate = self._forward_real_valued(pilot_symbols.imag, channel_conditions)
172
 
173
  # Combine into complex tensor
174
  channel_estimate = torch.complex(real_estimate, imag_estimate)
175
 
176
  return channel_estimate
177
 
178
+ def _forward_real_valued(self, x: torch.Tensor,
179
+ channel_conditions: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
180
  """
181
  Process real-valued input through the estimation pipeline.
182
 
183
  Args:
184
  x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
185
+ channel_conditions: Channel conditions for adaptation (optional)
186
 
187
  Returns:
188
  Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
 
205
  # Stage 3: Convert to patch embeddings
206
  patch_embeddings = self.patch_embedder(conv_enhanced)
207
 
208
+ # Stage 4: Apply channel adaptation if enabled
209
+ if self.use_channel_adaptation and channel_conditions is not None:
210
+ encoded_channel_condition = self.channel_adapter(*channel_conditions)
211
+ transformer_input = torch.cat((patch_embeddings, encoded_channel_condition), dim=2)
212
+ else:
213
+ transformer_input = patch_embeddings
214
+
215
+ # Stage 5: Transformer processing for long-range dependencies
216
+ transformer_output = self.transformer_encoder(transformer_input)
217
 
218
+ # Stage 6: Reconstruct spatial representation
219
  reconstructed = self.patch_reconstructor(transformer_output)
220
 
221
+ # Stage 7: Apply residual connection
222
  residual_combined = conv_enhanced + reconstructed
223
 
224
+ # Stage 8: Final convolutional refinement
225
  refined_output = torch.squeeze(self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1)
226
 
227
  return refined_output
 
230
  """Return model configuration and statistics."""
231
  return {
232
  'model_name': self.__class__.__name__,
233
+ 'channel_adaptation': self.use_channel_adaptation,
234
  'ofdm_size': self.ofdm_size,
235
  'pilot_size': self.pilot_size,
236
  'patch_size': self.model_config.patch_size,
237
  'patch_length': self.patch_length,
238
+ 'adaptive_patch_length': self.adaptive_patch_length,
239
  'model_dim': self.model_config.model_dim,
240
  'num_layers': self.model_config.num_layers,
241
  'device': str(self.device),
242
  'total_parameters': sum(p.numel() for p in self.parameters()),
243
  'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
244
  }
245
+
246
+
247
+ class FortiTranEstimator(BaseFortiTranEstimator):
248
+ """
249
+ Standard Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
250
+
251
+ This is the base version without channel adaptation features.
252
+ """
253
+
254
+ def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
255
+ """
256
+ Initialize the FortiTranEstimator.
257
+
258
+ Args:
259
+ system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
260
+ model_config: Model architecture configuration (patch size, layers, etc.)
261
+ """
262
+ super().__init__(system_config, model_config, use_channel_adaptation=False)
src/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for OFDM channel estimation."""
2
+ import re
3
+ import torch
4
+
5
+ def extract_values(file_name):
6
+ """
7
+ Extract channel information from a file name.
8
+
9
+ Parses file names with format:
10
+ '{number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'
11
+
12
+ Args:
13
+ file_name: The file name from which to extract channel information
14
+
15
+ Returns:
16
+ tuple: A tuple containing:
17
+ - file_number (torch.Tensor): The file number
18
+ - snr (torch.Tensor): Signal-to-noise ratio value
19
+ - delay_spread (torch.Tensor): Delay spread value
20
+ - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
21
+ - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
22
+ - channel_type (list): The channel type
23
+
24
+ Raises:
25
+ ValueError: If the file name does not match the expected pattern
26
+ """
27
+ pattern = r'(\d+)_SNR-(\d+)_DS-(\d+)_DOP-(\d+)_N-(\d+)_([A-Z\-]+)\.mat'
28
+ match = re.match(pattern, file_name)
29
+ if match:
30
+ file_no = torch.tensor([int(match.group(1))], dtype=torch.float)
31
+ snr_value = torch.tensor([int(match.group(2))], dtype=torch.float)
32
+ ds_value = torch.tensor([int(match.group(3))], dtype=torch.float)
33
+ dop_value = torch.tensor([int(match.group(4))], dtype=torch.float)
34
+ n = torch.tensor([int(match.group(5))], dtype=torch.float)
35
+ channel_type = [match.group(6)]
36
+ return file_no, snr_value, ds_value, dop_value, n, channel_type
37
+ else:
38
+ raise ValueError("Cannot extract file information.")
39
+
40
+ def concat_complex_channel(channel_matrix):
41
+ """
42
+ Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.
43
+
44
+ Transforms a complex tensor into a real-valued tensor by concatenating
45
+ the real and imaginary components along the specified dimension.
46
+
47
+ Args:
48
+ channel_matrix: Complex channel matrix
49
+
50
+ Returns:
51
+ Real-valued channel matrix with concatenated real and imaginary parts
52
+ """
53
+ real_channel_m = torch.real(channel_matrix)
54
+ imag_channel_m = torch.imag(channel_matrix)
55
+ cat_channel_m = torch.cat((real_channel_m, imag_channel_m), dim=1)
56
+ return cat_channel_m