Commit
·
54d5c08
1
Parent(s):
0c90d5f
added parser and trainer blueprints
Browse files- config/{model_config.yaml → adafortitran.yaml} +0 -1
- config/fortitran.yaml +9 -0
- src/main/__init__.py +0 -0
- src/main/parser.py +241 -0
- src/main/train_helpers.py +268 -0
- src/main/trainer.py +398 -0
- src/models/__init__.py +3 -0
- src/utils.py +283 -1
config/{model_config.yaml → adafortitran.yaml}
RENAMED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
patch_size: [3, 2]
|
| 2 |
num_layers: 6
|
| 3 |
-
device: "cpu"
|
| 4 |
model_dim: 128
|
| 5 |
num_head: 4
|
| 6 |
activation: 'gelu'
|
|
|
|
| 1 |
patch_size: [3, 2]
|
| 2 |
num_layers: 6
|
|
|
|
| 3 |
model_dim: 128
|
| 4 |
num_head: 4
|
| 5 |
activation: 'gelu'
|
config/fortitran.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
patch_size: [3, 2]
|
| 2 |
+
num_layers: 6
|
| 3 |
+
model_dim: 128
|
| 4 |
+
num_head: 4
|
| 5 |
+
activation: 'gelu'
|
| 6 |
+
dropout: 0.1
|
| 7 |
+
max_seq_len: 512
|
| 8 |
+
pos_encoding_type: 'learnable'
|
| 9 |
+
adaptive_token_length: 6
|
src/main/__init__.py
ADDED
|
File without changes
|
src/main/parser.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Command line argument parser for OFDM channel estimation model training.
|
| 3 |
+
|
| 4 |
+
This module provides functionality for parsing and validating command-line arguments
|
| 5 |
+
used in training OFDM channel estimation models. It defines the available parameters,
|
| 6 |
+
their types, default values, and validation rules to ensure proper configuration
|
| 7 |
+
of training runs.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class TrainingArguments:
    """Container for OFDM model training arguments.

    Stores, validates, and provides access to all parameters needed for
    training an OFDM channel estimation model.

    Attributes:
        # Model Configuration
        model_name: Supports Linear, AdaFortiTran, or FortiTran training
        system_config_path: Path to OFDM system configuration file

        # Dataset Paths
        train_set: Path to training dataset directory
        val_set: Path to validation dataset directory
        test_set: Path to test dataset directory

        # Experiment Settings
        exp_id: Experiment identifier string
        python_log_level: Logging verbosity level
        tensorboard_log_dir: Directory for tensorboard logs

        # Training Hyperparameters
        batch_size: Number of samples per batch
        lr: Learning rate for optimizer
        max_epoch: Maximum number of training epochs
        patience: Early stopping patience in epochs

        # Hardware & Evaluation
        cuda: CUDA device index
        test_every_n: Number of epochs between test evaluations
    """

    # Model Configuration
    model_name: str
    system_config_path: Path

    # Dataset Paths
    train_set: Path
    val_set: Path
    test_set: Path

    # Experiment Settings
    exp_id: str
    python_log_level: str = "INFO"
    tensorboard_log_dir: Path = Path("runs")

    # Training Hyperparameters
    batch_size: int = 64
    lr: float = 1e-3
    max_epoch: int = 10
    patience: int = 3

    # Hardware & Evaluation
    cuda: int = 0
    test_every_n: int = 10

    # Both conventional YAML extensions are accepted (generalized from the
    # original '.yaml'-only check, which rejected valid '.yml' configs).
    _YAML_SUFFIXES = ('.yaml', '.yml')

    def __post_init__(self) -> None:
        """Validate arguments after initialization.

        Runs multiple validation checks on the provided arguments to ensure
        they are consistent and valid for training.

        Raises:
            ValueError: If any validation check fails
        """
        self._validate_paths()
        self._validate_numeric_args()

    def _validate_paths(self) -> None:
        """Validate path-related arguments.

        Checks that the config file exists and has the correct extension.

        Raises:
            ValueError: If the config file doesn't exist or isn't a YAML file
        """
        if not self.system_config_path.exists():
            raise ValueError(f"Config file not found: {self.system_config_path}")

        if self.system_config_path.suffix not in self._YAML_SUFFIXES:
            raise ValueError(f"Config file must be a YAML (.yaml/.yml) file: {self.system_config_path}")

    def _validate_numeric_args(self) -> None:
        """Validate numeric arguments.

        Ensures that all numeric parameters have appropriate values:
        - test_every_n, max_epoch, patience, batch_size, lr must be positive
        - cuda must be non-negative

        Raises:
            ValueError: If any numeric argument has an invalid value
        """
        if self.test_every_n <= 0:
            raise ValueError(f"test_every_n must be positive, got: {self.test_every_n}")

        if self.max_epoch <= 0:
            raise ValueError(f"max_epoch must be positive, got: {self.max_epoch}")

        if self.patience <= 0:
            raise ValueError(f"patience must be positive, got: {self.patience}")

        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got: {self.batch_size}")

        if self.cuda < 0:
            raise ValueError(f"cuda must be non-negative, got: {self.cuda}")

        if self.lr <= 0:
            raise ValueError(f"lr must be positive, got: {self.lr}")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def parse_arguments() -> TrainingArguments:
    """Parse command-line arguments for training an OFDM channel estimation model.

    Sets up an argument parser with all required and optional arguments,
    processes the command line input, and returns a validated TrainingArguments
    object with all parameters needed for model training.

    Returns:
        TrainingArguments: Validated arguments for model training

    Raises:
        ValueError: If validation fails for any arguments
        SystemExit: If argument parsing fails (raised by argparse)
    """
    parser = argparse.ArgumentParser(
        description='Train an OFDM channel estimation model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Declarative (flag, options) specs keep the group wiring in one place.
    required_specs = [
        ('--model_name', dict(
            type=str,
            required=True,
            choices=['Linear', 'AdaFortiTran', 'FortiTran'],
            help='Model type to train (Linear, AdaFortiTran, or FortiTran)',
        )),
        ('--system_config_path', dict(
            type=Path,
            required=True,
            help='Path to YAML file containing OFDM system parameters',
        )),
        ('--train_set', dict(
            type=Path,
            required=True,
            help='Training dataset folder path',
        )),
        ('--val_set', dict(
            type=Path,
            required=True,
            help='Validation dataset folder path',
        )),
        ('--test_set', dict(
            type=Path,
            required=True,
            help='Test dataset folder path',
        )),
        ('--exp_id', dict(
            type=str,
            required=True,
            help='Experiment identifier for log folder naming',
        )),
    ]
    optional_specs = [
        ('--python_log_level', dict(
            type=str,
            default="INFO",
            help='Logger level for python logging module',
        )),
        ('--tensorboard_log_dir', dict(
            type=Path,
            default="runs",
            help='Directory for tensorboard logs',
        )),
        ('--test_every_n', dict(
            type=int,
            default=10,
            help='Test model every N epochs',
        )),
        ('--max_epoch', dict(
            type=int,
            default=10,
            help='Maximum number of training epochs',
        )),
        ('--patience', dict(
            type=int,
            default=3,
            help='Early stopping patience (epochs)',
        )),
        ('--batch_size', dict(
            type=int,
            default=64,
            help='Training batch size',
        )),
        ('--cuda', dict(
            type=int,
            default=0,
            help='CUDA device index (0 for single GPU)',
        )),
        ('--lr', dict(
            type=float,
            default=1e-3,
            help='Initial learning rate',
        )),
    ]

    required = parser.add_argument_group('required arguments')
    for flag, options in required_specs:
        required.add_argument(flag, **options)

    optional = parser.add_argument_group('optional arguments')
    for flag, options in optional_specs:
        optional.add_argument(flag, **options)

    namespace = parser.parse_args()

    # TrainingArguments.__post_init__ performs validation on construction.
    return TrainingArguments(**vars(namespace))
|
src/main/train_helpers.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training helper functions for OFDM channel estimation models.
|
| 3 |
+
|
| 4 |
+
This module provides utility functions for training, evaluating, and testing
|
| 5 |
+
deep learning models for OFDM channel estimation tasks. It includes functions
|
| 6 |
+
for performing training epochs, model evaluation, prediction generation,
|
| 7 |
+
and performance statistics calculation across different test conditions.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Dict, List, Tuple, Union, Callable
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
from torch.optim import Optimizer
|
| 15 |
+
from torch.optim.lr_scheduler import ExponentialLR
|
| 16 |
+
from src.utils import to_db, concat_complex_channel
|
| 17 |
+
|
| 18 |
+
# Type aliases
|
| 19 |
+
ComplexTensor = torch.Tensor # Complex tensor
|
| 20 |
+
BatchType = Tuple[ComplexTensor, ComplexTensor, Union[Dict, None]]
|
| 21 |
+
TestDataLoadersType = List[Tuple[str, DataLoader]]
|
| 22 |
+
StatsType = Dict[int, float]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_all_test_stats(
    model: nn.Module,
    test_dataloaders: Dict[str, TestDataLoadersType],
    loss_fn: Callable
) -> Tuple[StatsType, StatsType, StatsType]:
    """
    Evaluate model on all test datasets.

    Calculates performance statistics (MSE in dB) for a model across different
    test conditions: Delay Spread (DS), Max Doppler Shift (MDS), and
    Signal-to-Noise Ratio (SNR).

    Args:
        model: Model to evaluate
        test_dataloaders: Dictionary containing DataLoader objects for test sets:
            - "DS": Delay Spread test set
            - "MDS": Max Doppler Shift test set
            - "SNR": Signal-to-Noise Ratio test set
        loss_fn: Loss function for evaluation

    Returns:
        Tuple containing statistics (MSE in dB) for DS, MDS, and SNR test sets,
        where each set of statistics is a dictionary mapping parameter values to MSE
    """
    # Evaluate each condition in a fixed order so the returned tuple is
    # always (DS, MDS, SNR).
    ds, mds, snr = (
        get_test_stats(model, test_dataloaders[condition], loss_fn)
        for condition in ("DS", "MDS", "SNR")
    )
    return ds, mds, snr
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_test_stats(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType,
    loss_fn: Callable
) -> StatsType:
    """
    Evaluate model on provided test dataloaders.

    Calculates performance statistics (MSE in dB) for a model on a
    specific set of test conditions.

    Args:
        model: Model to evaluate
        test_dataloaders: List of (name, DataLoader) tuples for test sets,
            where names are in format "parameter_value"
        loss_fn: Loss function for evaluation

    Returns:
        Dictionary mapping test parameter values (as integers) to MSE values in dB
    """
    def _param_value(entry):
        # Names are "parameter_value"; order loaders by the numeric value.
        return int(entry[0].split("_")[1])

    results: StatsType = {}
    for name, loader in sorted(test_dataloaders, key=_param_value):
        param, value = name.split("_")
        mse_db = to_db(eval_model(model, loader, loss_fn))
        print(f"{param}:{value} Test MSE: {mse_db:.4f} dB")
        results[int(value)] = mse_db

    return results
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def eval_model(
    model: nn.Module,
    eval_dataloader: DataLoader,
    loss_fn: Callable
) -> float:
    """
    Evaluate model on given dataloader.

    Calculates the average loss for a model on a dataset without
    performing parameter updates.

    Args:
        model: Model to evaluate
        eval_dataloader: DataLoader containing evaluation data
        loss_fn: Loss function for computing error

    Returns:
        Average validation loss (adjusted for complex values)

    Notes:
        Loss is multiplied by 2 to account for complex-valued matrices being
        represented as real-valued matrices of double size.
    """
    model.eval()
    accumulated = 0.0

    with torch.no_grad():
        for batch in eval_dataloader:
            estimated, ideal = _forward_pass(batch, model)
            loss = _compute_loss(estimated, ideal, loss_fn)
            # Weight the batch-mean loss by batch size; the factor of 2
            # compensates for the real/imag concatenation.
            accumulated += 2 * loss.item() * batch[0].size(0)

    return accumulated / len(eval_dataloader.dataset)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def predict_channels(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType
) -> Dict[int, Dict[str, ComplexTensor]]:
    """
    Generate channel predictions for test datasets.

    Creates predictions for a sample from each test dataset to enable
    visualization and error analysis.

    Args:
        model: Model to use for predictions
        test_dataloaders: List of (name, DataLoader) tuples for test sets,
            where names are in format "parameter_value"

    Returns:
        Dictionary mapping test parameter values (as integers) to dictionaries containing
        estimated and ideal channels for a single sample
    """
    predictions: Dict[int, Dict[str, ComplexTensor]] = {}
    ordered = sorted(test_dataloaders, key=lambda entry: int(entry[0].split("_")[1]))

    for name, loader in ordered:
        with torch.no_grad():
            # One batch per condition is enough for visualization.
            sample_batch = next(iter(loader))
            estimated, ideal = _forward_pass(sample_batch, model)

            _, value = name.split("_")
            predictions[int(value)] = {
                "estimated_channel": estimated[0],
                "ideal_channel": ideal[0]
            }

    return predictions
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def train_epoch(
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: Callable,
    scheduler: ExponentialLR,
    train_dataloader: DataLoader
) -> float:
    """
    Train model for one epoch.

    Performs a complete training iteration over the dataset, including:
    - Forward pass through the model
    - Loss calculation
    - Backpropagation
    - Parameter updates
    - Learning rate scheduling

    Args:
        model: Model to train
        optimizer: Optimizer for updating model parameters
        loss_fn: Loss function for computing error
        scheduler: Learning rate scheduler
        train_dataloader: DataLoader containing training data

    Returns:
        Average training loss for the epoch (adjusted for complex values)

    Notes:
        Loss is multiplied by 2 to account for complex-valued matrices being
        represented as real-valued matrices of double size.
    """
    model.train()
    accumulated = 0.0

    for batch in train_dataloader:
        optimizer.zero_grad()
        estimated, ideal = _forward_pass(batch, model)
        loss = _compute_loss(estimated, ideal, loss_fn)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size; the factor of 2
        # compensates for the real/imag concatenation.
        accumulated += 2 * loss.item() * batch[0].size(0)

    # Learning rate decays once per epoch, not per batch.
    scheduler.step()
    return accumulated / len(train_dataloader.dataset)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, ComplexTensor]:
    """
    Perform forward pass through model.

    Splits the complex input into real and imaginary parts, runs each part
    through the model separately, and recombines the outputs into a complex
    tensor. Adaptive models additionally receive the batch metadata.

    Args:
        batch: Tuple containing (estimated_channel, ideal_channel, metadata)
        model: Model to use for processing; must expose ``name`` and ``device``

    Returns:
        Tuple of (processed_estimated_channel, ideal_channel)

    Raises:
        ValueError: If model name is not recognized

    Notes:
        NOTE(review): the trainer registry includes a "linear" model, but this
        dispatch only handles "fortitran"/"MMSE"/"adafortitran" — a Linear
        model would hit the ValueError branch. Confirm intended handling.
        NOTE(review): the input is not explicitly moved to ``model.device``
        here — presumably the model or dataloader handles that; verify.
    """
    estimated_channel, ideal_channel, meta_data = batch

    if model.name in ["fortitran", "MMSE"]:
        # Real and imaginary planes are processed independently by the same
        # real-valued model, then recombined.
        h_est_re = model(torch.real(estimated_channel))
        h_est_im = model(torch.imag(estimated_channel))
        estimated_channel = torch.complex(h_est_re, h_est_im)
    elif model.name == "adafortitran":
        # Adaptive variant also consumes per-batch metadata.
        h_est_re = model(torch.real(estimated_channel), meta_data)
        h_est_im = model(torch.imag(estimated_channel), meta_data)
        estimated_channel = torch.complex(h_est_re, h_est_im)
    else:
        raise ValueError(f"Unknown model type: {model.name}")

    # Ground truth is moved to the model's device so losses compare
    # same-device tensors.
    return estimated_channel, ideal_channel.to(model.device)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _compute_loss(
    estimated_channel: ComplexTensor,
    ideal_channel: ComplexTensor,
    loss_fn: Callable
) -> torch.Tensor:
    """
    Calculate loss between estimated and ideal channels.

    Computes the loss between model output and ground truth using the specified
    loss function, with appropriate handling of complex values.

    Args:
        estimated_channel: Estimated channel from model
        ideal_channel: Ground truth ideal channel
        loss_fn: Loss function to compute error

    Returns:
        Computed loss value as a scalar tensor
    """
    # Both operands are converted to stacked real representations before the
    # real-valued loss is applied.
    prediction = concat_complex_channel(estimated_channel)
    target = concat_complex_channel(ideal_channel)
    return loss_fn(prediction, target)
|
src/main/trainer.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OFDM channel estimation model training module.
|
| 3 |
+
|
| 4 |
+
This module provides functionality for training and evaluating deep learning models
|
| 5 |
+
for OFDM channel estimation tasks. It includes a ModelTrainer class that handles
|
| 6 |
+
the complete training workflow, including model initialization, data loading,
|
| 7 |
+
training loop management, evaluation, and result logging.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn, optim
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 14 |
+
from typing import Dict, Tuple, Type, Union
|
| 15 |
+
|
| 16 |
+
from .parser import TrainingArguments
|
| 17 |
+
from src.data.dataset import MatDataset, get_test_dataloaders
|
| 18 |
+
from src.models import LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator
|
| 19 |
+
from src.utils import (
|
| 20 |
+
EarlyStopping,
|
| 21 |
+
get_ls_mse_per_folder,
|
| 22 |
+
get_model_details,
|
| 23 |
+
get_test_stats_plot,
|
| 24 |
+
get_error_images
|
| 25 |
+
)
|
| 26 |
+
from src.main.train_helpers import (
|
| 27 |
+
get_all_test_stats,
|
| 28 |
+
train_epoch,
|
| 29 |
+
eval_model,
|
| 30 |
+
predict_channels
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# A union type representing supported model classes
|
| 34 |
+
ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ModelTrainer:
    """Handles the training and evaluation of deep learning models.

    This class manages the complete lifecycle of model training, including:
    - Model initialization and configuration
    - Optimizer and loss function setup
    - Data loading and preprocessing
    - Training loop execution
    - Performance evaluation
    - Result logging and visualization via TensorBoard

    Attributes:
        MODEL_REGISTRY: Dictionary mapping model names to model classes
        system_config: OFDM system configuration
        args: Training arguments
        device: PyTorch device for computation
        writer: TensorBoard SummaryWriter for logging
        model: Initialized model instance
        optimizer: Torch optimizer for training
        scheduler: Learning rate scheduler
        early_stopper: Helper for early stopping
        train_loader: DataLoader for training set
        val_loader: DataLoader for validation set
        test_loaders: Dictionary of test set DataLoaders
    """

    # Maps a lowercase model identifier to its estimator class.
    # NOTE(review): the CLI parser's --model_name choices are capitalized
    # ('Linear', 'AdaFortiTran', 'FortiTran') while these keys are lowercase;
    # confirm the name is normalized before the registry lookup, otherwise
    # the lookup raises KeyError.
    MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
        "linear": LinearEstimator,
        "adafortitran": AdaFortiTranEstimator,
        "fortitran": FortiTranEstimator,
    }
|
| 68 |
+
|
| 69 |
+
    def __init__(self, system_config: Dict, args: TrainingArguments):
        """
        Initialize the ModelTrainer.

        Args:
            system_config: OFDM system configuration dictionary from YAML file
            args: Validated training arguments
        """
        self.system_config = system_config
        self.args = args
        # NOTE(review): always targets CUDA; no CPU fallback if no GPU is
        # available — confirm whether a torch.cuda.is_available() guard is
        # intended.
        self.device = torch.device(f"cuda:{args.cuda}")
        self.writer = self._setup_tensorboard()

        # Model must be created after self.device/self.writer are set, since
        # _initialize_model reads both.
        self.model = self._initialize_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.995)
        self.early_stopper = EarlyStopping(patience=args.patience)

        self.training_loss = self._get_loss_function()
        self.comparison_loss = nn.MSELoss()  # used for test set evaluation

        self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
|
| 91 |
+
|
| 92 |
+
def _get_loss_function(self) -> nn.Module:
|
| 93 |
+
"""Get the appropriate loss function based on arguments.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
The selected PyTorch loss function based on args.loss_type
|
| 97 |
+
|
| 98 |
+
Raises:
|
| 99 |
+
ValueError: If an unsupported loss type is specified
|
| 100 |
+
"""
|
| 101 |
+
if self.args.loss_type == LossType.MSE:
|
| 102 |
+
return nn.MSELoss()
|
| 103 |
+
elif self.args.loss_type == LossType.MAE:
|
| 104 |
+
return nn.L1Loss()
|
| 105 |
+
elif self.args.loss_type == LossType.HUBER:
|
| 106 |
+
return nn.HuberLoss()
|
| 107 |
+
else:
|
| 108 |
+
raise ValueError(f"Unsupported loss type: {self.args.loss_type}")
|
| 109 |
+
|
| 110 |
+
def _setup_tensorboard(self) -> SummaryWriter:
|
| 111 |
+
"""Set up TensorBoard logging.
|
| 112 |
+
|
| 113 |
+
Creates a unique log directory based on model name and experiment ID.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Initialized SummaryWriter for TensorBoard logging
|
| 117 |
+
|
| 118 |
+
Raises:
|
| 119 |
+
RuntimeError: If experiment directory already exists
|
| 120 |
+
"""
|
| 121 |
+
log_path = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
|
| 122 |
+
if log_path.exists():
|
| 123 |
+
raise RuntimeError(f"Experiment {log_path} already exists")
|
| 124 |
+
|
| 125 |
+
return SummaryWriter(str(log_path))
|
| 126 |
+
|
| 127 |
+
def _initialize_model(self) -> ModelType:
|
| 128 |
+
"""Initialize the model based on configuration.
|
| 129 |
+
|
| 130 |
+
Creates an instance of the appropriate model class from the registry,
|
| 131 |
+
logs model summary information, and returns the initialized model.
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Initialized model instance of the specified type
|
| 135 |
+
"""
|
| 136 |
+
model_class = self.MODEL_REGISTRY[self.args.model_name]
|
| 137 |
+
model = model_class(self.device, self.config, vars(self.args))
|
| 138 |
+
|
| 139 |
+
num_params, model_summary = get_model_details(model)
|
| 140 |
+
print(model_summary)
|
| 141 |
+
print(f"Model name: {self.config['model_name']}\nNumber of parameters: {num_params}")
|
| 142 |
+
self.writer.add_text("Number of Parameters", str(num_params))
|
| 143 |
+
|
| 144 |
+
return model
|
| 145 |
+
|
| 146 |
+
def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
    """Build the training, validation, and test dataloaders.

    Creates DataLoader instances for:
    - Training dataset (shuffled)
    - Validation dataset (deterministic order)
    - Test datasets grouped by test condition (DS, MDS, SNR)

    Returns:
        Tuple containing (train_loader, val_loader, test_loaders_dict),
        where test_loaders_dict maps each condition name to the loaders
        returned by get_test_dataloaders.
    """
    train_dataset = MatDataset(
        self.args.train_set,
        self.args.pilot_dims,
        return_type=self.config["return_type"]
    )
    val_dataset = MatDataset(
        self.args.val_set,
        self.args.pilot_dims,
        return_type=self.config["return_type"]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=self.args.batch_size,
        shuffle=True
    )
    # Fix: the validation loader was previously shuffled too. Evaluation
    # gains nothing from shuffling, and a deterministic order makes
    # validation runs reproducible.
    val_loader = DataLoader(
        val_dataset,
        batch_size=self.args.batch_size,
        shuffle=False
    )

    # One loader group per test condition; test-set folders follow the
    # "<condition>_test_set" naming convention.
    test_loaders = {
        condition: get_test_dataloaders(
            self.args.test_set / f"{condition}_test_set",
            vars(self.args),
            self.config["return_type"],
        )
        for condition in ("DS", "MDS", "SNR")
    }

    return train_loader, val_loader, test_loaders
|
| 200 |
+
|
| 201 |
+
def _log_test_results(
    self,
    epoch: int,
    test_stats: Dict[str, Dict],
    ls_stats: Dict[str, Dict]
) -> None:
    """Push per-condition test visualizations to TensorBoard.

    For each test condition (DS, MDS, SNR) logs a line plot comparing the
    model against the LS baseline, plus a heatmap of channel-estimation
    errors on that condition's test loaders.

    Args:
        epoch: Current training epoch (0-based).
        test_stats: Per-condition test statistics for the model.
        ls_stats: Per-condition test statistics for the LS baseline.
    """
    for condition in ("DS", "MDS", "SNR"):
        # Line plot: model vs. LS baseline.
        comparison_figure = get_test_stats_plot(
            x_name=condition,
            stats=[test_stats[condition], ls_stats[condition]],
            methods=[self.config["model_name"], "LS"]
        )
        self.writer.add_figure(
            tag=f"MSE vs. {condition} (Epoch:{epoch + 1})",
            figure=comparison_figure
        )

        # Error heatmaps for this condition's predicted channels.
        channels = predict_channels(self.model, self.test_loaders[condition])
        self.writer.add_figure(
            tag=f"{condition} Error Images (Epoch:{epoch + 1})",
            figure=get_error_images(condition, channels, show=False)
        )
|
| 241 |
+
|
| 242 |
+
def _run_tests(self, epoch: int) -> None:
    """Evaluate the model on every test set and log the results.

    Gathers per-condition statistics for the model and for the LS
    baseline, then hands both to _log_test_results for visualization.

    Args:
        epoch: Current training epoch (0-based).
    """
    ds_stats, mds_stats, snr_stats = get_all_test_stats(
        self.model,
        self.test_loaders,
        self.comparison_loss
    )
    test_stats = {"DS": ds_stats, "MDS": mds_stats, "SNR": snr_stats}

    # LS baseline statistics, one entry per test condition.
    ls_stats = {
        condition: get_ls_mse_per_folder(self.args.test_set / f"{condition}_test_set")
        for condition in ("DS", "MDS", "SNR")
    }

    self._log_test_results(epoch, test_stats, ls_stats)
|
| 270 |
+
|
| 271 |
+
def _log_final_metrics(self, final_epoch: int) -> None:
    """Log final training metrics and hyperparameters.

    Records the hyperparameters used in training plus final performance
    metrics across all test conditions for experiment tracking.

    Args:
        final_epoch: The index of the final training epoch (0-based).
    """
    str_params = {k: str(v) for k, v in vars(self.args).items()}
    self.writer.add_hparams(
        hparam_dict=str_params,
        metric_dict={"last_epoch": final_epoch + 1},
        run_name="."
    )

    try:
        # Compute model and LS-baseline statistics ONCE. The previous
        # version recomputed get_all_test_stats and all three
        # get_ls_mse_per_folder scans inside the per-key loop, tripling
        # the work for identical results.
        ds_stats, mds_stats, snr_stats = get_all_test_stats(
            self.model,
            self.test_loaders,
            self.comparison_loss
        )
        test_stats = {"DS": ds_stats, "MDS": mds_stats, "SNR": snr_stats}
        ls_stats = {
            key: get_ls_mse_per_folder(self.args.test_set / f"{key}_test_set")
            for key in test_stats
        }

        for key, stats in test_stats.items():
            for val in stats:
                self.writer.add_scalars(
                    key,
                    {
                        "LS": ls_stats[key][val],
                        self.config["model_name"]: stats[val]
                    },
                    val
                )
    except Exception as e:
        # Best-effort logging boundary: a failed final report should not
        # crash the whole training run.
        self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
|
| 318 |
+
|
| 319 |
+
def train(self) -> None:
    """Execute the training loop.

    Per epoch: one training pass, one validation pass, an early-stopping
    check, and periodic testing every ``test_every_n`` epochs. Final
    metrics are logged and the TensorBoard writer is closed at the end.
    """
    try:
        from tqdm import tqdm
        use_tqdm = True
    except ImportError:
        use_tqdm = False
        print("tqdm not found, progress bar will not be displayed")

    def emit(message: str) -> None:
        # Route messages through tqdm when available so they do not
        # corrupt the progress bar rendering.
        if use_tqdm:
            pbar.write(message)
        else:
            print(message)

    epoch = None

    if use_tqdm:
        pbar = tqdm(range(self.args.max_epoch), desc="Training")
    else:
        pbar = range(self.args.max_epoch)

    for epoch in pbar:
        # Training step
        train_loss = train_epoch(
            self.model,
            self.optimizer,
            self.training_loss,
            self.scheduler,
            self.train_loader
        )
        self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)

        # Validation step
        val_loss = eval_model(self.model, self.val_loader, self.training_loss)
        self.writer.add_scalar('Loss/Val', val_loss, epoch + 1)

        if use_tqdm:
            pbar.set_description(
                f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if self.early_stopper.early_stop(val_loss):
            emit(f"Early stopping triggered at epoch {epoch + 1}")
            break

        # Periodic testing
        if (epoch + 1) % self.args.test_every_n == 0:
            emit(f"Test results after epoch {epoch + 1}:\n" + 50 * "-")
            self._run_tests(epoch)

    # Fix: with max_epoch == 0 the loop never runs, `epoch` stays None,
    # and the original unconditional call crashed on `final_epoch + 1`.
    if epoch is not None:
        self._log_final_metrics(epoch)
    self.writer.close()
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def train(config: Dict, args: TrainingArguments) -> None:
    """
    Train an OFDM channel estimation model.

    Main entry point for model training: builds a ModelTrainer from the
    supplied configuration and arguments, then runs its training loop.

    Args:
        config: Model configuration dictionary loaded from a YAML file,
            containing model architecture and training parameters
        args: Validated training arguments with dataset paths,
            hyperparameters, and logging configuration
    """
    ModelTrainer(config, args).train()
|
src/models/__init__.py
CHANGED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.models.fortitran import FortiTranEstimator
|
| 2 |
+
from src.models.adafortitran import AdaFortiTranEstimator
|
| 3 |
+
from src.models.linear import LinearEstimator
|
src/utils.py
CHANGED
|
@@ -1,7 +1,70 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def extract_values(file_name):
|
| 6 |
"""
|
| 7 |
Extract channel information from a file name.
|
|
@@ -37,6 +100,58 @@ def extract_values(file_name):
|
|
| 37 |
else:
|
| 38 |
raise ValueError("Cannot extract file information.")
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def concat_complex_channel(channel_matrix):
|
| 41 |
"""
|
| 42 |
Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.
|
|
@@ -54,3 +169,170 @@ def concat_complex_channel(channel_matrix):
|
|
| 54 |
imag_channel_m = torch.imag(channel_matrix)
|
| 55 |
cat_channel_m = torch.cat((real_channel_m, imag_channel_m), dim=1)
|
| 56 |
return cat_channel_m
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for OFDM channel estimation.
|
| 3 |
+
|
| 4 |
+
This module provides various utility functions for processing, visualizing,
|
| 5 |
+
and analyzing OFDM channel estimation data, including complex channel matrices,
|
| 6 |
+
error calculations, model statistics, and visualization tools for
|
| 7 |
+
performance evaluation.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Union
|
| 12 |
import re
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import scipy.io as sio
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
from prettytable import PrettyTable
|
| 19 |
import torch
|
| 20 |
|
| 21 |
+
|
| 22 |
+
class EarlyStopping:
    """Stop training when validation loss stops improving.

    Tracks the best (lowest) validation loss seen so far and counts how
    many consecutive epochs pass without a new best. Once that count
    reaches ``patience``, training should stop.

    Attributes:
        patience: Number of non-improving epochs tolerated before stopping
        remaining_patience: Epochs left before stopping triggers
        min_loss: Best validation loss observed so far (None before first call)
    """

    def __init__(self, patience: int = 3):
        """
        Initialize early stopping.

        Args:
            patience: Number of epochs to wait before stopping
        """
        self.patience = patience
        self.remaining_patience = patience
        self.min_loss: Optional[float] = None

    def early_stop(self, loss: float) -> bool:
        """
        Check if training should stop.

        Args:
            loss: Current validation loss

        Returns:
            Whether to stop training
        """
        improved = self.min_loss is None or loss < self.min_loss
        if improved:
            # New best loss: remember it and restore full patience.
            self.min_loss = loss
            self.remaining_patience = self.patience
            return False

        # No improvement (ties count as no improvement): burn one epoch
        # of patience and stop once it runs out.
        self.remaining_patience -= 1
        return self.remaining_patience == 0
|
| 66 |
+
|
| 67 |
+
|
| 68 |
def extract_values(file_name):
|
| 69 |
"""
|
| 70 |
Extract channel information from a file name.
|
|
|
|
| 100 |
else:
|
| 101 |
raise ValueError("Cannot extract file information.")
|
| 102 |
|
| 103 |
+
|
| 104 |
+
def get_error_images(variable, channel_data, show=False):
    """
    Create visualizations of channel estimation errors.

    Generates one heatmap per entry of ``channel_data`` showing the
    absolute difference between estimated and ideal channels, with a
    shared color scale and a single shared colorbar.

    Args:
        variable: Name of the variable being visualized (e.g., 'SNR', 'DS')
        channel_data: Dictionary mapping parameter values to dictionaries
            containing 'estimated_channel' and 'ideal_channel' (torch tensors)
        show: Whether to display the figure immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure with error heatmaps
    """
    # squeeze=False guarantees a 2-D axes array even with a single
    # condition: plt.subplots(1, 1) otherwise returns a bare Axes object
    # and the original axes[i] indexing crashed.
    fig, axes = plt.subplots(1, len(channel_data), figsize=(20, 6), squeeze=False)

    cax = None
    for i, (key, channels) in enumerate(channel_data.items()):
        # Absolute error between estimated and ideal channels.
        error_matrix = torch.abs(channels['estimated_channel'] - channels['ideal_channel'])
        error_numpy = error_matrix.detach().cpu().numpy()

        # Shared vmin/vmax keep the color scale consistent across subplots.
        ax = axes[0, i]
        cax = ax.imshow(error_numpy, cmap='viridis', aspect=14 / 120, vmin=0, vmax=1)
        ax.set_title(f"{variable} = {key}")
        ax.set_xlabel('Columns (14)')
        ax.set_ylabel('Rows (120)')

    # Dedicated axis for the shared colorbar: [left, bottom, width, height].
    cbar_ax = fig.add_axes((0.92, 0.15, 0.02, 0.7))
    if cax is not None:
        # Guard: with empty channel_data there is no image to anchor the
        # colorbar to (the original raised NameError on `cax`).
        fig.colorbar(cax, cax=cbar_ax, label='Error Magnitude')

    # Leave space on the right for the colorbar.
    fig.tight_layout(rect=(0, 0, 0.9, 1))

    if show:
        plt.show()

    return fig
|
| 153 |
+
|
| 154 |
+
|
| 155 |
def concat_complex_channel(channel_matrix):
|
| 156 |
"""
|
| 157 |
Convert a complex channel matrix into a real matrix by concatenating real and imaginary parts.
|
|
|
|
| 169 |
imag_channel_m = torch.imag(channel_matrix)
|
| 170 |
cat_channel_m = torch.cat((real_channel_m, imag_channel_m), dim=1)
|
| 171 |
return cat_channel_m
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def inverse_concat_complex_channel(channel_matrix: torch.Tensor) -> torch.Tensor:
    """
    Reconstruct complex channel matrix from concatenated real matrix.

    Reverses the operation performed by concat_complex_channel by
    splitting the tensor along its last dimension and combining the two
    halves into a complex tensor.

    Args:
        channel_matrix: Real-valued matrix of shape (..., 2*T),
            e.g. (B, F, 2*T) for a batch of channel estimates

    Returns:
        Complex matrix of shape (..., T)
    """
    # Fix: the split point is computed from shape[-1], so the slice must
    # also be taken on the LAST dimension. The previous `[:, :split_idx]`
    # sliced dimension 1, which is only correct for 2-D inputs and
    # produced mismatched shapes (a crash) for batched 3-D inputs.
    split_idx = channel_matrix.shape[-1] // 2
    return torch.complex(
        channel_matrix[..., :split_idx],
        channel_matrix[..., split_idx:]
    )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_test_stats_plot(x_name, stats, methods, show=False):
    """
    Plot test statistics for multiple methods as line graphs.

    Draws one dashed line per method, each with a distinct marker,
    showing the metric (MSE in dB) against sorted parameter values.

    Args:
        x_name: Label for the x-axis (e.g., 'SNR', 'DS', 'Epoch')
        stats: List of dictionaries, one per method, mapping x-values to
            performance metrics
        methods: List of method names corresponding to each entry in stats
        show: Whether to display the plot immediately (default: False)

    Returns:
        matplotlib.figure.Figure: The generated figure object

    Raises:
        AssertionError: If stats and methods lists have different lengths
    """
    assert len(stats) == len(methods), "Provided stats and methods do not have the same length."
    fig = plt.figure()
    marker_iter = iter(["*", "x", "+", "D", "v", "^"])
    for stat in stats:
        try:
            marker = next(marker_iter)
        except StopIteration:
            # Ran out of markers: restart with an extended marker set.
            marker_iter = iter(["o", "*", "x", "+", "D", "v", "^"])
            marker = next(marker_iter)

        # Sort by x-value so the line is drawn left to right.
        ordered = sorted(stat.items(), key=lambda kv: kv[0])
        x_vals = [k for k, _ in ordered]
        y_vals = [v for _, v in ordered]
        plt.plot(x_vals, y_vals, f"{marker}--")

    plt.xlabel(x_name)
    plt.ylabel("MSE Error (dB)")
    plt.grid()
    plt.legend(methods)
    if show:
        plt.show()
    return fig
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def to_db(val):
    """
    Convert values to decibels (dB).

    Applies 10 * log10(val) to express a linear-scale value in dB.

    Args:
        val: Input value or array to convert to dB (must be positive)

    Returns:
        The input value(s) converted to decibels
    """
    return np.log10(val) * 10
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def mse(x, y):
    """
    Calculate mean squared error (MSE) in dB between two complex arrays.

    Averages the squared magnitudes of the element-wise differences and
    converts the result to the decibel scale.

    Args:
        x: First complex numpy array
        y: Second complex numpy array (same shape as x)

    Returns:
        MSE in decibels (dB) between the two arrays
    """
    diff_magnitude = np.abs(x - y)
    # 10 * log10(mean(|x - y|^2)) — the dB conversion inlined.
    return 10 * np.log10(np.mean(diff_magnitude ** 2))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def get_ls_mse_per_folder(folders_dir: Union[Path, str]):
    """
    Calculate average MSE for LS estimates in each subfolder.

    For every subfolder (named '<prefix>_<val>' with integer val) of the
    given directory, averages the MSE between the least-squares channel
    estimate and the ideal channel over all .mat files it contains.

    Args:
        folders_dir: Path to directory containing subfolders with .mat files

    Returns:
        Dictionary mapping the integer value from each subfolder name to
        its average MSE in dB

    Notes:
        - Each .mat file must contain a 3D matrix 'H' where H[:,:,0] is
          the ideal channel and H[:,:,2] is the LS estimate
        - Subfolders are processed in ascending order of their value
        - Per-file MSEs are in dB, so this averages dB values
    """
    results = {}
    subfolders = sorted(os.listdir(folders_dir), key=lambda name: int(name.split("_")[1]))
    for subfolder in subfolders:
        _, value_str = subfolder.split("_")
        folder_path = os.path.join(folders_dir, subfolder)
        file_names = os.listdir(folder_path)

        total = 0
        for file_name in file_names:
            mat_data = sio.loadmat(os.path.join(folder_path, file_name))['H']
            total += mse(mat_data[:, :, 2], mat_data[:, :, 0])

        results[int(value_str)] = total / len(file_names)
    return results
|
| 312 |
+
|
| 313 |
+
def get_model_details(model):
    """
    Get parameter counts and structure details for a PyTorch model.

    Counts the trainable parameters of the model and builds a formatted
    table listing the parameter count of each named, trainable parameter.

    Args:
        model: PyTorch model to analyze

    Returns:
        tuple containing:
            - total_params: Total number of trainable parameters
            - table: PrettyTable showing parameter counts by module
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    # Only trainable (requires_grad) parameters are counted and listed.
    trainable = (
        (name, param.numel())
        for name, param in model.named_parameters()
        if param.requires_grad
    )
    for name, count in trainable:
        table.add_row([name, count])
        total_params += count
    return total_params, table
|