BerkIGuler committed on
Commit
54d5c08
·
1 Parent(s): 0c90d5f

added parser and trainer blueprints

Browse files
config/{model_config.yaml → adafortitran.yaml} RENAMED
@@ -1,6 +1,5 @@
1
  patch_size: [3, 2]
2
  num_layers: 6
3
- device: "cpu"
4
  model_dim: 128
5
  num_head: 4
6
  activation: 'gelu'
 
1
  patch_size: [3, 2]
2
  num_layers: 6
 
3
  model_dim: 128
4
  num_head: 4
5
  activation: 'gelu'
config/fortitran.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ patch_size: [3, 2]
2
+ num_layers: 6
3
+ model_dim: 128
4
+ num_head: 4
5
+ activation: 'gelu'
6
+ dropout: 0.1
7
+ max_seq_len: 512
8
+ pos_encoding_type: 'learnable'
9
+ adaptive_token_length: 6
src/main/__init__.py ADDED
File without changes
src/main/parser.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Command line argument parser for OFDM channel estimation model training.
3
+
4
+ This module provides functionality for parsing and validating command-line arguments
5
+ used in training OFDM channel estimation models. It defines the available parameters,
6
+ their types, default values, and validation rules to ensure proper configuration
7
+ of training runs.
8
+ """
9
+
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ import argparse
13
+
14
+
15
@dataclass
class TrainingArguments:
    """Validated container for OFDM model training parameters.

    Collects every setting needed to train an OFDM channel estimation model
    and checks the values for consistency on construction.

    Attributes:
        model_name: Which estimator to train (Linear, AdaFortiTran, or FortiTran)
        system_config_path: YAML file describing the OFDM system
        train_set: Training dataset directory
        val_set: Validation dataset directory
        test_set: Test dataset directory
        exp_id: Identifier used to name the experiment's log folder
        python_log_level: Verbosity for the Python logging module
        tensorboard_log_dir: Root directory for TensorBoard logs
        batch_size: Samples per training batch
        lr: Optimizer learning rate
        max_epoch: Upper bound on training epochs
        patience: Epochs without improvement before early stopping
        cuda: CUDA device index
        test_every_n: Epoch interval between test-set evaluations
    """

    # Model Configuration
    model_name: str
    system_config_path: Path

    # Dataset Paths
    train_set: Path
    val_set: Path
    test_set: Path

    # Experiment Settings
    exp_id: str
    python_log_level: str = "INFO"
    tensorboard_log_dir: Path = Path("runs")

    # Training Hyperparameters
    batch_size: int = 64
    lr: float = 1e-3
    max_epoch: int = 10
    patience: int = 3

    # Hardware & Evaluation
    cuda: int = 0
    test_every_n: int = 10

    def __post_init__(self) -> None:
        """Run all consistency checks right after construction.

        Raises:
            ValueError: If any validation check fails
        """
        self._validate_paths()
        self._validate_numeric_args()

    def _validate_paths(self) -> None:
        """Check that the system config file exists and is a YAML file.

        Raises:
            ValueError: On a missing file or a non-.yaml suffix
        """
        config = self.system_config_path
        if not config.exists():
            raise ValueError(f"Config file not found: {config}")

        if config.suffix != '.yaml':
            raise ValueError(f"Config file must be a .yaml file: {config}")

    def _validate_numeric_args(self) -> None:
        """Check numeric settings: counts and lr positive, cuda non-negative.

        Raises:
            ValueError: If any numeric argument has an invalid value
        """
        # Checked in this order so the first reported error matches the
        # original contract when several values are invalid at once.
        for attr in ("test_every_n", "max_epoch", "patience", "batch_size"):
            value = getattr(self, attr)
            if value <= 0:
                raise ValueError(f"{attr} must be positive, got: {value}")

        if self.cuda < 0:
            raise ValueError(f"cuda must be non-negative, got: {self.cuda}")

        if self.lr <= 0:
            raise ValueError(f"lr must be positive, got: {self.lr}")
125
+
126
+
127
def parse_arguments() -> TrainingArguments:
    """Parse command-line arguments for training an OFDM channel estimation model.

    Builds the argument parser, reads the command line, and packages the
    result into a validated TrainingArguments instance.

    Returns:
        TrainingArguments: Validated arguments for model training

    Raises:
        ValueError: If validation of the parsed values fails
        SystemExit: On malformed command-line input (raised by argparse)
    """
    parser = argparse.ArgumentParser(
        description='Train an OFDM channel estimation model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, keyword options) specs for arguments the user must supply.
    required_specs = [
        ('--model_name', dict(
            type=str,
            choices=['Linear', 'AdaFortiTran', 'FortiTran'],
            help='Model type to train (Linear, AdaFortiTran, or FortiTran)')),
        ('--system_config_path', dict(
            type=Path,
            help='Path to YAML file containing OFDM system parameters')),
        ('--train_set', dict(type=Path, help='Training dataset folder path')),
        ('--val_set', dict(type=Path, help='Validation dataset folder path')),
        ('--test_set', dict(type=Path, help='Test dataset folder path')),
        ('--exp_id', dict(type=str, help='Experiment identifier for log folder naming')),
    ]
    required = parser.add_argument_group('required arguments')
    for flag, options in required_specs:
        required.add_argument(flag, required=True, **options)

    # (flag, keyword options) specs for arguments with defaults.
    optional_specs = [
        ('--python_log_level', dict(
            type=str, default="INFO",
            help='Logger level for python logging module')),
        ('--tensorboard_log_dir', dict(
            type=Path, default="runs",
            help='Directory for tensorboard logs')),
        ('--test_every_n', dict(
            type=int, default=10,
            help='Test model every N epochs')),
        ('--max_epoch', dict(
            type=int, default=10,
            help='Maximum number of training epochs')),
        ('--patience', dict(
            type=int, default=3,
            help='Early stopping patience (epochs)')),
        ('--batch_size', dict(
            type=int, default=64,
            help='Training batch size')),
        ('--cuda', dict(
            type=int, default=0,
            help='CUDA device index (0 for single GPU)')),
        ('--lr', dict(
            type=float, default=1e-3,
            help='Initial learning rate')),
    ]
    optional = parser.add_argument_group('optional arguments')
    for flag, options in optional_specs:
        optional.add_argument(flag, **options)

    # Validation happens inside TrainingArguments.__post_init__.
    return TrainingArguments(**vars(parser.parse_args()))
src/main/train_helpers.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training helper functions for OFDM channel estimation models.
3
+
4
+ This module provides utility functions for training, evaluating, and testing
5
+ deep learning models for OFDM channel estimation tasks. It includes functions
6
+ for performing training epochs, model evaluation, prediction generation,
7
+ and performance statistics calculation across different test conditions.
8
+ """
9
+
10
+ from typing import Dict, List, Tuple, Union, Callable
11
+ import torch
12
+ from torch import nn
13
+ from torch.utils.data import DataLoader
14
+ from torch.optim import Optimizer
15
+ from torch.optim.lr_scheduler import ExponentialLR
16
+ from src.utils import to_db, concat_complex_channel
17
+
18
# Type aliases used throughout this module.
ComplexTensor = torch.Tensor  # Complex tensor
# One batch: (estimated_channel, ideal_channel, optional metadata dict).
BatchType = Tuple[ComplexTensor, ComplexTensor, Union[Dict, None]]
# Test sets are ("parameter_value" name, DataLoader) pairs.
TestDataLoadersType = List[Tuple[str, DataLoader]]
# Maps a swept test-parameter value to its MSE in dB.
StatsType = Dict[int, float]
23
+
24
+
25
def get_all_test_stats(
    model: nn.Module,
    test_dataloaders: Dict[str, TestDataLoadersType],
    loss_fn: Callable
) -> Tuple[StatsType, StatsType, StatsType]:
    """
    Evaluate a model on the DS, MDS, and SNR test suites.

    Computes MSE-in-dB statistics for each family of test conditions:
    Delay Spread ("DS"), Max Doppler Shift ("MDS"), and
    Signal-to-Noise Ratio ("SNR").

    Args:
        model: Model to evaluate
        test_dataloaders: Maps "DS"/"MDS"/"SNR" to their test dataloaders
        loss_fn: Loss function for evaluation

    Returns:
        (ds_stats, mds_stats, snr_stats), each a dict mapping the swept
        parameter value to its MSE in dB
    """
    ds_stats, mds_stats, snr_stats = (
        get_test_stats(model, test_dataloaders[condition], loss_fn)
        for condition in ("DS", "MDS", "SNR")
    )
    return ds_stats, mds_stats, snr_stats
53
+
54
+
55
def get_test_stats(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType,
    loss_fn: Callable
) -> StatsType:
    """
    Evaluate a model over one family of test conditions.

    Dataloader names are expected in "parameter_value" form; loaders are
    processed in ascending order of the numeric value.

    Args:
        model: Model to evaluate
        test_dataloaders: (name, DataLoader) pairs for the test sets
        loss_fn: Loss function for evaluation

    Returns:
        Dict mapping each integer parameter value to its MSE in dB
    """
    def numeric_value(entry):
        # "DS_50" -> 50; used for ordering the loaders.
        return int(entry[0].split("_")[1])

    results: StatsType = {}
    for name, loader in sorted(test_dataloaders, key=numeric_value):
        param_name, param_value = name.split("_")
        mse_db = to_db(eval_model(model, loader, loss_fn))
        print(f"{param_name}:{param_value} Test MSE: {mse_db:.4f} dB")
        results[int(param_value)] = mse_db

    return results
89
+
90
+
91
def eval_model(
    model: nn.Module,
    eval_dataloader: DataLoader,
    loss_fn: Callable
) -> float:
    """
    Compute the average loss of a model over a dataset without training.

    Runs in eval mode with gradients disabled; no parameters are updated.

    Args:
        model: Model to evaluate
        eval_dataloader: DataLoader containing evaluation data
        loss_fn: Loss function for computing error

    Returns:
        Average loss over the dataset (adjusted for complex values)

    Notes:
        Each batch loss is doubled because complex-valued matrices are
        represented as real-valued matrices of twice the size.
    """
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in eval_dataloader:
            estimate, target = _forward_pass(batch, model)
            loss = _compute_loss(estimate, target, loss_fn)
            # Weight by batch size so the final mean is per-sample.
            total_loss += 2 * loss.item() * batch[0].size(0)

    return total_loss / len(eval_dataloader.dataset)
125
+
126
+
127
def predict_channels(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType
) -> Dict[int, Dict[str, ComplexTensor]]:
    """
    Produce one sample prediction per test dataloader.

    Useful for visualization and qualitative error analysis: for each test
    condition, the first sample of the first batch is predicted and paired
    with its ground truth.

    Args:
        model: Model to use for predictions
        test_dataloaders: (name, DataLoader) pairs, names in "parameter_value" form

    Returns:
        Dict mapping each integer parameter value to a dict with keys
        "estimated_channel" and "ideal_channel" holding single-sample tensors
    """
    ordered = sorted(
        test_dataloaders,
        key=lambda entry: int(entry[0].split("_")[1])
    )

    predictions: Dict[int, Dict[str, ComplexTensor]] = {}
    for name, loader in ordered:
        with torch.no_grad():
            first_batch = next(iter(loader))
            estimates, targets = _forward_pass(first_batch, model)

        value = int(name.split("_")[1])
        predictions[value] = {
            "estimated_channel": estimates[0],
            "ideal_channel": targets[0]
        }

    return predictions
164
+
165
+
166
def train_epoch(
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: Callable,
    scheduler: ExponentialLR,
    train_dataloader: DataLoader
) -> float:
    """
    Run one full training pass over the dataset.

    Each batch goes through a forward pass, loss computation, backprop, and
    a parameter update; the learning-rate scheduler steps once per epoch.

    Args:
        model: Model to train
        optimizer: Optimizer for updating model parameters
        loss_fn: Loss function for computing error
        scheduler: Learning rate scheduler (stepped once at epoch end)
        train_dataloader: DataLoader containing training data

    Returns:
        Average training loss for the epoch (adjusted for complex values)

    Notes:
        Each batch loss is doubled because complex-valued matrices are
        represented as real-valued matrices of twice the size.
    """
    model.train()
    running_loss = 0.0

    for batch in train_dataloader:
        optimizer.zero_grad()
        estimate, target = _forward_pass(batch, model)
        loss = _compute_loss(estimate, target, loss_fn)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the final mean is per-sample.
        running_loss += 2 * loss.item() * batch[0].size(0)

    scheduler.step()
    return running_loss / len(train_dataloader.dataset)
211
+
212
+
213
def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, ComplexTensor]:
    """
    Run one batch through the model, handling real/imaginary parts separately.

    The real and imaginary components of the noisy channel estimate are fed
    through the model independently and recombined into one complex tensor.
    Adaptive models additionally receive the batch metadata.

    Args:
        batch: Tuple of (estimated_channel, ideal_channel, metadata)
        model: Model to use for processing

    Returns:
        Tuple of (refined_estimated_channel, ideal_channel_on_model_device)

    Raises:
        ValueError: If model.name is not a recognized model type
    """
    estimated_channel, ideal_channel, meta_data = batch

    # NOTE(review): the "linear" model registered in the trainer is not
    # handled here — confirm whether it should take the "fortitran" path.
    if model.name in ["fortitran", "MMSE"]:
        real_out = model(torch.real(estimated_channel))
        imag_out = model(torch.imag(estimated_channel))
    elif model.name == "adafortitran":
        # The adaptive variant conditions on per-sample metadata.
        real_out = model(torch.real(estimated_channel), meta_data)
        imag_out = model(torch.imag(estimated_channel), meta_data)
    else:
        raise ValueError(f"Unknown model type: {model.name}")

    return torch.complex(real_out, imag_out), ideal_channel.to(model.device)
244
+
245
+
246
def _compute_loss(
    estimated_channel: ComplexTensor,
    ideal_channel: ComplexTensor,
    loss_fn: Callable
) -> torch.Tensor:
    """
    Compute the loss between estimated and ideal channels.

    Both complex channels are first flattened into real matrices (real and
    imaginary parts concatenated) so any real-valued loss function applies.

    Args:
        estimated_channel: Estimated channel from model
        ideal_channel: Ground truth ideal channel
        loss_fn: Loss function to compute error

    Returns:
        Computed loss value as a scalar tensor
    """
    estimate_real = concat_complex_channel(estimated_channel)
    target_real = concat_complex_channel(ideal_channel)
    return loss_fn(estimate_real, target_real)
src/main/trainer.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OFDM channel estimation model training module.
3
+
4
+ This module provides functionality for training and evaluating deep learning models
5
+ for OFDM channel estimation tasks. It includes a ModelTrainer class that handles
6
+ the complete training workflow, including model initialization, data loading,
7
+ training loop management, evaluation, and result logging.
8
+ """
9
+
10
+ import torch
11
+ from torch import nn, optim
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.tensorboard import SummaryWriter
14
+ from typing import Dict, Tuple, Type, Union
15
+
16
+ from .parser import TrainingArguments
17
+ from src.data.dataset import MatDataset, get_test_dataloaders
18
+ from src.models import LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator
19
+ from src.utils import (
20
+ EarlyStopping,
21
+ get_ls_mse_per_folder,
22
+ get_model_details,
23
+ get_test_stats_plot,
24
+ get_error_images
25
+ )
26
+ from src.main.train_helpers import (
27
+ get_all_test_stats,
28
+ train_epoch,
29
+ eval_model,
30
+ predict_channels
31
+ )
32
+
33
# A union type representing supported model classes; these are the values
# stored in ModelTrainer.MODEL_REGISTRY.
ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
35
+
36
+
37
class ModelTrainer:
    """Handles the training and evaluation of deep learning models.

    This class manages the complete lifecycle of model training, including:
    - Model initialization and configuration
    - Optimizer and loss function setup
    - Data loading and preprocessing
    - Training loop execution
    - Performance evaluation
    - Result logging and visualization via TensorBoard

    Attributes:
        MODEL_REGISTRY: Dictionary mapping model names to model classes
        system_config: OFDM system configuration
        args: Training arguments
        device: PyTorch device for computation
        writer: TensorBoard SummaryWriter for logging
        model: Initialized model instance
        optimizer: Torch optimizer for training
        scheduler: Learning rate scheduler
        early_stopper: Helper for early stopping
        train_loader: DataLoader for training set
        val_loader: DataLoader for validation set
        test_loaders: Dictionary of test set DataLoaders
    """

    MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
        "linear": LinearEstimator,
        "adafortitran": AdaFortiTranEstimator,
        "fortitran": FortiTranEstimator,
    }

    # Loss functions selectable via an optional `loss_type` training argument.
    _LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {
        "mse": nn.MSELoss,
        "mae": nn.L1Loss,
        "huber": nn.HuberLoss,
    }

    def __init__(self, system_config: Dict, args: TrainingArguments):
        """
        Initialize the ModelTrainer.

        Args:
            system_config: Model/system configuration dictionary from YAML file
            args: Validated training arguments
        """
        self.system_config = system_config
        self.args = args
        # Fall back to CPU when CUDA is unavailable so training does not
        # crash on CPU-only machines.
        self.device = (
            torch.device(f"cuda:{args.cuda}")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        self.writer = self._setup_tensorboard()

        self.model = self._initialize_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.995)
        self.early_stopper = EarlyStopping(patience=args.patience)

        self.training_loss = self._get_loss_function()
        self.comparison_loss = nn.MSELoss()  # used for test set evaluation

        self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()

    def _get_loss_function(self) -> nn.Module:
        """Get the training loss function.

        Reads an optional `loss_type` attribute from the training arguments
        ("mse", "mae", or "huber") and defaults to MSE when absent.
        (TrainingArguments does not currently declare `loss_type`; the
        previous implementation referenced an undefined `LossType` enum.)

        Returns:
            Instantiated PyTorch loss module

        Raises:
            ValueError: If an unsupported loss type is specified
        """
        loss_type = str(getattr(self.args, "loss_type", "mse")).lower()
        try:
            return self._LOSS_REGISTRY[loss_type]()
        except KeyError:
            raise ValueError(f"Unsupported loss type: {loss_type}") from None

    def _setup_tensorboard(self) -> SummaryWriter:
        """Set up TensorBoard logging.

        Creates a unique log directory based on model name and experiment ID.

        Returns:
            Initialized SummaryWriter for TensorBoard logging

        Raises:
            RuntimeError: If experiment directory already exists
        """
        log_path = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
        if log_path.exists():
            raise RuntimeError(f"Experiment {log_path} already exists")

        return SummaryWriter(str(log_path))

    def _initialize_model(self) -> ModelType:
        """Initialize the model based on configuration.

        Creates an instance of the appropriate model class from the registry,
        logs model summary information, and returns the initialized model.

        Returns:
            Initialized model instance of the specified type
        """
        # Registry keys are lowercase while the CLI accepts capitalized names
        # ("Linear", "FortiTran", ...), so normalize before lookup.
        model_class = self.MODEL_REGISTRY[self.args.model_name.lower()]
        # Fixed: previously read `self.config`, which was never assigned.
        model = model_class(self.device, self.system_config, vars(self.args))

        num_params, model_summary = get_model_details(model)
        print(model_summary)
        print(f"Model name: {self.args.model_name}\nNumber of parameters: {num_params}")
        self.writer.add_text("Number of Parameters", str(num_params))

        return model

    def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
        """Initialize all required dataloaders.

        Creates DataLoader instances for:
        - Training dataset
        - Validation dataset
        - Test datasets grouped by test condition (DS, MDS, SNR)

        Returns:
            Tuple containing (train_loader, val_loader, test_loaders_dict)
        """
        # NOTE(review): `args.pilot_dims` is not declared on TrainingArguments —
        # confirm where this attribute is expected to come from.
        return_type = self.system_config["return_type"]

        train_dataset = MatDataset(
            self.args.train_set,
            self.args.pilot_dims,
            return_type=return_type
        )
        val_dataset = MatDataset(
            self.args.val_set,
            self.args.pilot_dims,
            return_type=return_type
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            shuffle=True
        )

        # One group of test dataloaders per swept condition.
        test_loaders = {
            condition: get_test_dataloaders(
                self.args.test_set / f"{condition}_test_set",
                vars(self.args),
                return_type
            )
            for condition in ("DS", "MDS", "SNR")
        }

        return train_loader, val_loader, test_loaders

    def _ls_baseline_stats(self) -> Dict[str, Dict]:
        """Compute LS-estimator baseline statistics for every test condition."""
        return {
            condition: get_ls_mse_per_folder(self.args.test_set / f"{condition}_test_set")
            for condition in ("DS", "MDS", "SNR")
        }

    def _log_test_results(
        self,
        epoch: int,
        test_stats: Dict[str, Dict],
        ls_stats: Dict[str, Dict]
    ) -> None:
        """Log test results to TensorBoard.

        Creates and logs visualizations comparing model performance against
        the baseline LS estimator across different test conditions.

        Args:
            epoch: Current training epoch
            test_stats: Dictionary of test statistics for the model
            ls_stats: Dictionary of test statistics for the LS baseline
        """
        for key in ("DS", "MDS", "SNR"):
            # Model-vs-LS MSE curve for this condition.
            self.writer.add_figure(
                tag=f"MSE vs. {key} (Epoch:{epoch + 1})",
                figure=get_test_stats_plot(
                    x_name=key,
                    stats=[test_stats[key], ls_stats[key]],
                    methods=[self.args.model_name, "LS"]
                )
            )

            # Qualitative error images from one predicted sample per loader.
            predicted_channels = predict_channels(
                self.model,
                self.test_loaders[key]
            )
            self.writer.add_figure(
                tag=f"{key} Error Images (Epoch:{epoch + 1})",
                figure=get_error_images(
                    key,
                    predicted_channels,
                    show=False
                )
            )

    def _run_tests(self, epoch: int) -> None:
        """Run tests and log results.

        Evaluates the model on all test datasets, compares with LS baseline,
        and logs performance metrics and visualizations.

        Args:
            epoch: Current training epoch
        """
        ds_stats, mds_stats, snr_stats = get_all_test_stats(
            self.model,
            self.test_loaders,
            self.comparison_loss
        )

        test_stats = {
            "DS": ds_stats,
            "MDS": mds_stats,
            "SNR": snr_stats
        }

        self._log_test_results(epoch, test_stats, self._ls_baseline_stats())

    def _log_final_metrics(self, final_epoch: int) -> None:
        """Log final training metrics and hyperparameters.

        Records hyperparameters used in training and final performance metrics
        across all test conditions for experiment tracking.

        Args:
            final_epoch: The index of the final training epoch
        """
        str_params = {k: str(v) for k, v in vars(self.args).items()}
        self.writer.add_hparams(
            hparam_dict=str_params,
            metric_dict={"last_epoch": final_epoch + 1},
            run_name="."
        )

        try:
            # Evaluate once and reuse (previously all statistics were
            # recomputed from scratch inside the per-condition loop).
            ds_stats, mds_stats, snr_stats = get_all_test_stats(
                self.model,
                self.test_loaders,
                self.comparison_loss
            )
            ls_stats = self._ls_baseline_stats()
            final_stats = {
                "DS": ds_stats,
                "MDS": mds_stats,
                "SNR": snr_stats
            }

            for key, stats in final_stats.items():
                for val in stats:
                    self.writer.add_scalars(
                        key,
                        {
                            "LS": ls_stats[key][val],
                            self.args.model_name: stats[val]
                        },
                        val
                    )
        except Exception as e:
            # Final-metric logging is best-effort; record the failure rather
            # than aborting a completed training run.
            self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")

    def train(self) -> None:
        """Execute the training loop.

        Runs the complete training process including:
        - Training and validation for each epoch
        - Periodic testing based on test_every_n
        - Early stopping when validation loss plateaus
        - Logging final metrics and results
        """
        try:
            from tqdm import tqdm
            use_tqdm = True
        except ImportError:
            use_tqdm = False
            print("tqdm not found, progress bar will not be displayed")

        epoch = None

        # Create progress bar if tqdm is available
        if use_tqdm:
            pbar = tqdm(range(self.args.max_epoch), desc="Training")
        else:
            pbar = range(self.args.max_epoch)

        for epoch in pbar:
            # Training step
            train_loss = train_epoch(
                self.model,
                self.optimizer,
                self.training_loss,
                self.scheduler,
                self.train_loader
            )
            self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)

            # Validation step
            val_loss = eval_model(self.model, self.val_loader, self.training_loss)
            self.writer.add_scalar('Loss/Val', val_loss, epoch + 1)

            # Update progress bar with loss info if tqdm is available
            if use_tqdm:
                pbar.set_description(
                    f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            if self.early_stopper.early_stop(val_loss):
                message = f"Early stopping triggered at epoch {epoch + 1}"
                if use_tqdm:
                    pbar.write(message)
                else:
                    print(message)
                break

            # Periodic testing
            if (epoch + 1) % self.args.test_every_n == 0:
                message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
                if use_tqdm:
                    pbar.write(message)
                else:
                    print(message)
                self._run_tests(epoch)

        self._log_final_metrics(epoch)
        self.writer.close()
381
+
382
+
383
def train(config: Dict, args: TrainingArguments) -> None:
    """
    Train an OFDM channel estimation model.

    Main entry point for training: builds a ModelTrainer from the given
    configuration and arguments, then runs its training loop.

    Args:
        config: Model configuration dictionary loaded from YAML file,
            containing model architecture and training parameters
        args: Validated training arguments containing all necessary parameters
            for model training, including dataset paths, hyperparameters,
            and logging configuration
    """
    ModelTrainer(config, args).train()
src/models/__init__.py CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from src.models.fortitran import FortiTranEstimator
2
+ from src.models.adafortitran import AdaFortiTranEstimator
3
+ from src.models.linear import LinearEstimator
src/utils.py CHANGED
@@ -1,7 +1,70 @@
1
- """Utility functions for OFDM channel estimation."""
 
 
 
 
 
 
 
 
 
 
2
  import re
 
 
 
 
 
 
3
  import torch
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def extract_values(file_name):
6
  """
7
  Extract channel information from a file name.
@@ -37,6 +100,58 @@ def extract_values(file_name):
37
  else:
38
  raise ValueError("Cannot extract file information.")
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def concat_complex_channel(channel_matrix):
41
  """
42
  Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.
@@ -54,3 +169,170 @@ def concat_complex_channel(channel_matrix):
54
  imag_channel_m = torch.imag(channel_matrix)
55
  cat_channel_m = torch.cat((real_channel_m, imag_channel_m), dim=1)
56
  return cat_channel_m
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for OFDM channel estimation.
3
+
4
+ This module provides various utility functions for processing, visualizing,
5
+ and analyzing OFDM channel estimation data, including complex channel matrices,
6
+ error calculations, model statistics, and visualization tools for
7
+ performance evaluation.
8
+ """
9
+
10
+ from pathlib import Path
11
+ from typing import Optional, Union
12
  import re
13
+ import os
14
+
15
+ import numpy as np
16
+ import scipy.io as sio
17
+ import matplotlib.pyplot as plt
18
+ from prettytable import PrettyTable
19
  import torch
20
 
21
+
22
class EarlyStopping:
    """Tracks validation loss and signals when training should stop.

    Remembers the lowest validation loss observed so far and counts the
    consecutive epochs without improvement; once that count reaches
    ``patience``, early stopping is triggered.

    Attributes:
        patience: Number of non-improving epochs tolerated before stopping
        remaining_patience: Epochs left before early stopping triggers
        min_loss: Lowest validation loss observed so far (None before any call)
    """

    def __init__(self, patience: int = 3):
        """
        Initialize early stopping.

        Args:
            patience: Number of epochs to wait before stopping
        """
        self.patience = patience
        self.remaining_patience = patience
        self.min_loss: Optional[float] = None

    def early_stop(self, loss: float) -> bool:
        """
        Check if training should stop.

        Args:
            loss: Current validation loss

        Returns:
            Whether to stop training
        """
        # First observation, or a new best loss: record it and restore
        # full patience. (On the very first call the counter is already
        # at full patience, so resetting it is a no-op.)
        if self.min_loss is None or loss < self.min_loss:
            self.min_loss = loss
            self.remaining_patience = self.patience
            return False

        # No improvement: consume one unit of patience; stop at zero.
        self.remaining_patience -= 1
        return self.remaining_patience == 0
66
+
67
+
68
  def extract_values(file_name):
69
  """
70
  Extract channel information from a file name.
 
100
  else:
101
  raise ValueError("Cannot extract file information.")
102
 
103
+
104
def get_error_images(variable, channel_data, show=False):
    """
    Create visualizations of channel estimation errors.

    Generates one error heatmap per entry in ``channel_data``, showing the
    absolute difference between estimated and ideal channels with a shared
    color scale and a common color bar.

    Args:
        variable: Name of the swept parameter being visualized (e.g., 'SNR', 'DS')
        channel_data: Dictionary mapping parameter values to dictionaries
            containing 'estimated_channel' and 'ideal_channel' torch tensors
        show: Whether to display the figure immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure with error heatmaps
    """
    fig, axes = plt.subplots(1, len(channel_data), figsize=(20, 6))
    # Bug fix: plt.subplots returns a bare Axes (not an array) when there is
    # a single column, so axes[i] below would fail; normalize to a 1-D array.
    axes = np.atleast_1d(axes)

    cax = None
    # Plot each subplot with consistent color scaling
    for i, (key, channels) in enumerate(channel_data.items()):
        estimated_channel = channels['estimated_channel']
        ideal_channel = channels['ideal_channel']

        # Absolute error between estimated and ideal channels
        error_matrix = torch.abs(estimated_channel - ideal_channel)
        error_numpy = error_matrix.detach().cpu().numpy()

        # Shared vmin/vmax keeps the colormap comparable across subplots
        ax = axes[i]
        cax = ax.imshow(error_numpy, cmap='viridis', aspect=14 / 120, vmin=0, vmax=1)
        ax.set_title(f"{variable} = {key}")
        ax.set_xlabel('Columns (14)')
        ax.set_ylabel('Rows (120)')

    # Create a new axis for the color bar to the right of the subplots
    cbar_ax = fig.add_axes((0.92, 0.15, 0.02, 0.7))  # [left, bottom, width, height]
    fig.colorbar(cax, cax=cbar_ax, label='Error Magnitude')

    # Adjust layout to prevent overlapping labels; leave room for the color bar
    fig.tight_layout(rect=(0, 0, 0.9, 1))

    if show:
        plt.show()

    return fig
153
+
154
+
155
  def concat_complex_channel(channel_matrix):
156
  """
157
  Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.
 
169
  imag_channel_m = torch.imag(channel_matrix)
170
  cat_channel_m = torch.cat((real_channel_m, imag_channel_m), dim=1)
171
  return cat_channel_m
172
+
173
+
174
def inverse_concat_complex_channel(channel_matrix: torch.Tensor) -> torch.Tensor:
    """
    Reconstruct a complex channel matrix from its concatenated real form.

    Inverts concat_complex_channel, which concatenates real and imaginary
    parts along dim 1: the input is split in half along dim 1 and the two
    halves are combined into a single complex tensor.

    Args:
        channel_matrix: Real-valued tensor whose dim 1 holds the real part
            followed by the imaginary part, e.g. shape (F, 2*T) for 2-D
            input or (B, 2*F, T) for batched input.

    Returns:
        Complex tensor with dim 1 halved, e.g. (F, T) or (B, F, T).
    """
    # Bug fix: the split index must come from dim 1 — the dimension actually
    # sliced below (and the one concat_complex_channel cats along) — not from
    # the last dimension. For 2-D input the two agree; for batched 3-D input
    # shape[-1] produced a wrong, mismatched split.
    split_idx = channel_matrix.shape[1] // 2
    return torch.complex(
        channel_matrix[:, :split_idx],
        channel_matrix[:, split_idx:]
    )
+ )
192
+
193
+
194
def get_test_stats_plot(x_name, stats, methods, show=False):
    """
    Plot test statistics for multiple methods as line graphs.

    Draws one line per entry of ``stats`` (points sorted by x value) on a
    single figure so performance metrics can be compared across conditions.

    Args:
        x_name: Label for the x-axis (e.g., 'SNR', 'DS', 'Epoch')
        stats: List of dictionaries where each dictionary maps x-values to
            performance metrics for a specific method
        methods: List of method names corresponding to each entry in stats
        show: Whether to display the plot immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure object

    Raises:
        AssertionError: If stats and methods lists have different lengths
    """
    assert len(stats) == len(methods), "Provided stats and methods do not have the same length."
    fig = plt.figure()
    marker_cycle = iter(["*", "x", "+", "D", "v", "^"])
    for method_stats in stats:
        # Take the next marker; once exhausted, restart from an extended
        # list (with "o" prepended) — same reseeding the original performed.
        try:
            marker = next(marker_cycle)
        except StopIteration:
            marker_cycle = iter(["o", "*", "x", "+", "D", "v", "^"])
            marker = next(marker_cycle)

        ordered = sorted(method_stats.items(), key=lambda item: item[0])
        x_vals = [key for key, _ in ordered]
        y_vals = [value for _, value in ordered]
        plt.plot(x_vals, y_vals, f"{marker}--")

    plt.xlabel(x_name)
    plt.ylabel("MSE Error (dB)")
    plt.grid()
    plt.legend(methods)
    if show:
        plt.show()
    return fig
239
+
240
+
241
def to_db(val):
    """
    Convert values to decibels (dB).

    Applies ``10 * log10(val)``.

    Args:
        val: Input value or array to convert to dB (must be positive)

    Returns:
        The input value(s) converted to decibels
    """
    return np.log10(val) * 10
254
+
255
+
256
def mse(x, y):
    """
    Calculate mean squared error (MSE) in dB between two complex arrays.

    Averages the squared magnitude of the element-wise difference between
    the two arrays, then converts the result to decibels.

    Args:
        x: First complex numpy array
        y: Second complex numpy array (same shape as x)

    Returns:
        MSE in decibels (dB) between the two arrays
    """
    squared_error = np.square(np.abs(x - y))
    return to_db(np.mean(squared_error))
273
+
274
+
275
def get_ls_mse_per_folder(folders_dir: Union[Path, str]):
    """
    Calculate average MSE for LS estimates in each subfolder.

    For each subfolder of ``folders_dir``, averages the mean squared error
    between the least-squares channel estimate and the ideal channel over
    all .mat files in that subfolder.

    Args:
        folders_dir: Path to directory containing subfolders with .mat files.
            Each subfolder should be named 'prefix_val' where val is an integer.

    Returns:
        Dictionary mapping the integer value from each subfolder name to the
        average MSE in dB.

    Notes:
        - Each .mat file should contain a 3D matrix 'H' where:
            - H[:,:,0] is the ideal channel
            - H[:,:,2] is the LS channel estimate
        - Subfolders are processed in ascending order of their integer value
        - Per-file dB values are averaged (mean of dBs, matching the original
          behavior), not the dB of the mean squared error
    """
    mse_sums = {}
    # rsplit keeps this robust for prefixes that themselves contain "_".
    folders = sorted(os.listdir(folders_dir),
                     key=lambda name: int(name.rsplit("_", 1)[1]))
    for folder in folders:
        val = int(folder.rsplit("_", 1)[1])
        folder_path = os.path.join(folders_dir, folder)
        # List the directory once and reuse it for both count and iteration.
        files = os.listdir(folder_path)
        if not files:
            # Skip empty folders instead of dividing by zero.
            continue
        mse_sum = 0.0
        for file in files:
            mat_data = sio.loadmat(os.path.join(folder_path, file))['H']
            ls_estimate = mat_data[:, :, 2]
            ideal = mat_data[:, :, 0]
            mse_sum += mse(ls_estimate, ideal)
        mse_sums[val] = mse_sum / len(files)
    return mse_sums
312
+
313
def get_model_details(model):
    """
    Get parameter counts and structure details for a PyTorch model.

    Walks the model's named parameters, counting only those that require
    gradients, and builds a formatted table of per-parameter counts.

    Args:
        model: PyTorch model to analyze

    Returns:
        tuple containing:
            - total_params: Total number of trainable parameters
            - table: PrettyTable showing parameter counts by module
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, param in model.named_parameters():
        # Frozen (non-trainable) parameters are excluded from the summary.
        if param.requires_grad:
            count = param.numel()
            table.add_row([name, count])
            total_params += count
    return total_params, table