BerkIGuler commited on
Commit
4e938bd
·
1 Parent(s): 54d5c08

refactored trainer class

Browse files
.gitignore CHANGED
@@ -1,2 +1,2 @@
1
  .idea/
2
-
 
1
  .idea/
2
+ +**/__pycache__/
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch
2
  pydantic
3
  yaml
4
- scipy
 
 
1
  torch
2
  pydantic
3
  yaml
4
+ scipy
5
+ tqdm
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes). View file
 
src/config/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (246 Bytes). View file
 
src/config/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (9.4 kB). View file
 
src/config/config_loader.py CHANGED
@@ -21,10 +21,8 @@ class ConfigLoader:
21
  system_config_path: Path to YAML configuration file for OFDM-related parameters
22
  model_config_path: Path to YAML configuration file for model-related parameters
23
 
24
-
25
  Returns:
26
- ModelConfig: Validated model configuration object
27
- SystemConfig: Validated system configuration object
28
 
29
  Raises:
30
  FileNotFoundError: If one of the config files doesn't exist
@@ -34,10 +32,10 @@ class ConfigLoader:
34
  model_config_path = Path(model_config_path)
35
 
36
  if not system_config_path.exists():
37
- raise FileNotFoundError(f"Configuration file not found: {system_config_path}")
38
 
39
  if not model_config_path.exists():
40
- raise FileNotFoundError(f"Configuration file not found: {model_config_path}")
41
 
42
  try:
43
  with open(system_config_path, 'r') as f:
@@ -55,16 +53,15 @@ class ConfigLoader:
55
  system_config = SystemConfig(**system_raw_config)
56
  self.logger.info(f"Successfully loaded system config from {system_config_path}")
57
  except ValidationError as e:
58
- raise ValueError(f"Configuration validation for {system_config_path} failed:\n{e}")
59
- if system_config:
60
- try:
61
- model_config = ModelConfig(system_config, **model_raw_config)
62
- self.logger.info(f"Successfully loaded model config from {model_config_path}")
63
- except ValidationError as e:
64
- raise ValueError(f"Configuration validation for {model_config_path} failed:\n{e}")
65
 
66
- return system_config, model_config
 
 
 
 
67
 
 
68
 
69
 
70
  def load_config(system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
 
21
  system_config_path: Path to YAML configuration file for OFDM-related parameters
22
  model_config_path: Path to YAML configuration file for model-related parameters
23
 
 
24
  Returns:
25
+ Tuple of (SystemConfig, ModelConfig): Validated configuration objects
 
26
 
27
  Raises:
28
  FileNotFoundError: If one of the config files doesn't exist
 
32
  model_config_path = Path(model_config_path)
33
 
34
  if not system_config_path.exists():
35
+ raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
36
 
37
  if not model_config_path.exists():
38
+ raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
39
 
40
  try:
41
  with open(system_config_path, 'r') as f:
 
53
  system_config = SystemConfig(**system_raw_config)
54
  self.logger.info(f"Successfully loaded system config from {system_config_path}")
55
  except ValidationError as e:
56
+ raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
 
 
 
 
 
 
57
 
58
+ try:
59
+ model_config = ModelConfig(**model_raw_config)
60
+ self.logger.info(f"Successfully loaded model config from {model_config_path}")
61
+ except ValidationError as e:
62
+ raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
63
 
64
+ return system_config, model_config
65
 
66
 
67
  def load_config(system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
src/config/schemas.py CHANGED
@@ -1,5 +1,5 @@
1
  from pydantic import BaseModel, Field, model_validator
2
- from typing import Self, Tuple
3
  import torch
4
 
5
 
@@ -14,8 +14,19 @@ class PilotParams(BaseModel):
14
 
15
 
16
  class ModelParams(BaseModel):
17
- patch_size: Tuple[int, int] = Field(default=(10, 4), description="Patch size as (height, width)")
18
- num_layers: int = Field(default=6, gt=0, description="Number of model layers")
 
 
 
 
 
 
 
 
 
 
 
19
  device: str = Field(default="cpu", description="Device to use")
20
 
21
  @model_validator(mode='after')
@@ -103,39 +114,79 @@ class SystemConfig(BaseModel):
103
 
104
 
105
  class ModelConfig(BaseModel):
106
- system: SystemConfig
107
- model: ModelParams
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  @model_validator(mode='after')
110
- def validate_patch_constraints(self) -> Self:
111
- """Ensure patch size is compatible with OFDM dimensions."""
112
- patch_height, patch_width = self.model.patch_size
113
 
114
- if patch_height > self.system.ofdm.num_symbols:
115
- raise ValueError(
116
- f"Patch height ({patch_height}) cannot exceed "
117
- f"OFDM symbols ({self.system.ofdm.num_symbols})"
118
- )
 
 
 
 
119
 
120
- if patch_width > self.system.ofdm.num_scs:
121
- raise ValueError(
122
- f"Patch width ({patch_width}) cannot exceed "
123
- f"OFDM sub-carriers ({self.system.ofdm.num_scs})"
124
- )
125
 
126
- # Check if OFDM dimensions are divisible by patch size for clean patching
127
- if self.system.ofdm.num_symbols % patch_height != 0:
128
- raise ValueError(
129
- f"OFDM symbols ({self.system.ofdm.num_symbols}) must be divisible "
130
- f"by patch height ({patch_height}) for clean patching"
131
- )
132
 
133
- if self.system.ofdm.num_scs % patch_width != 0:
134
- raise ValueError(
135
- f"OFDM sub-carriers ({self.system.ofdm.num_scs}) must be divisible "
136
- f"by patch width ({patch_width}) for clean patching"
137
- )
 
 
 
 
 
 
 
 
 
138
 
139
- return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  model_config = {"extra": "forbid"}
 
1
  from pydantic import BaseModel, Field, model_validator
2
+ from typing import Self, Tuple, List, Optional
3
  import torch
4
 
5
 
 
14
 
15
 
16
  class ModelParams(BaseModel):
17
+ patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
18
+ num_layers: int = Field(..., gt=0, description="Number of transformer layers")
19
+ model_dim: int = Field(..., gt=0, description="Model dimension")
20
+ num_head: int = Field(..., gt=0, description="Number of attention heads")
21
+ activation: str = Field(default="gelu", description="Activation function")
22
+ dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
23
+ max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
24
+ pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
25
+ adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
26
+ channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
27
+ default=None,
28
+ description="Hidden sizes for channel adaptation layers"
29
+ )
30
  device: str = Field(default="cpu", description="Device to use")
31
 
32
  @model_validator(mode='after')
 
114
 
115
 
116
  class ModelConfig(BaseModel):
117
+ patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
118
+ num_layers: int = Field(..., gt=0, description="Number of transformer layers")
119
+ model_dim: int = Field(..., gt=0, description="Model dimension")
120
+ num_head: int = Field(..., gt=0, description="Number of attention heads")
121
+ activation: str = Field(default="gelu", description="Activation function")
122
+ dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
123
+ max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
124
+ pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
125
+ adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
126
+ channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
127
+ default=None,
128
+ description="Hidden sizes for channel adaptation layers"
129
+ )
130
+ device: str = Field(default="cpu", description="Device to use")
131
 
132
  @model_validator(mode='after')
133
+ def validate_device(self) -> Self:
134
+ """Validate that the specified device is available."""
135
+ device_str = self.device.lower()
136
 
137
+ # Handle 'auto' case - automatically select best available device
138
+ if device_str == 'auto':
139
+ if torch.cuda.is_available():
140
+ self.device = 'cuda'
141
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
142
+ self.device = 'mps' # Apple Silicon
143
+ else:
144
+ self.device = 'cpu'
145
+ return self
146
 
147
+ # Validate CPU
148
+ if device_str == 'cpu':
149
+ return self
 
 
150
 
151
+ # Validate CUDA devices
152
+ if device_str.startswith('cuda'):
153
+ if not torch.cuda.is_available():
154
+ raise ValueError("CUDA is not available on this system")
 
 
155
 
156
+ # Handle specific CUDA device (e.g., 'cuda:0', 'cuda:1')
157
+ if ':' in device_str:
158
+ try:
159
+ device_id = int(device_str.split(':')[1])
160
+ if device_id >= torch.cuda.device_count():
161
+ available_devices = list(range(torch.cuda.device_count()))
162
+ raise ValueError(
163
+ f"CUDA device {device_id} not available. "
164
+ f"Available CUDA devices: {available_devices}"
165
+ )
166
+ except (ValueError, IndexError) as e:
167
+ if "invalid literal" in str(e):
168
+ raise ValueError(f"Invalid CUDA device format: {device_str}")
169
+ raise
170
 
171
+ return self
172
+
173
+ # Validate MPS (Apple Silicon)
174
+ if device_str == 'mps':
175
+ if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
176
+ raise ValueError("MPS is not available on this system")
177
+ return self
178
+
179
+ # If we get here, the device is not recognized
180
+ available_devices = ['cpu']
181
+ if torch.cuda.is_available():
182
+ cuda_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
183
+ available_devices.extend(['cuda'] + cuda_devices)
184
+ if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
185
+ available_devices.append('mps')
186
+
187
+ raise ValueError(
188
+ f"Unsupported device: '{self.device}'. "
189
+ f"Available devices: {available_devices}"
190
+ )
191
 
192
  model_config = {"extra": "forbid"}
src/main.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main entry point for OFDM channel estimation model training.
4
+
5
+ This script provides the command-line interface for training OFDM channel estimation
6
+ models. It loads configuration files, parses command-line arguments, and initiates
7
+ the training process.
8
+ """
9
+
10
+ import logging
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ from src.main.parser import parse_arguments
15
+ from src.main.trainer import train
16
+ from src.config.config_loader import load_config
17
+
18
+
19
+ def setup_logging(log_level: str) -> None:
20
+ """Set up logging configuration.
21
+
22
+ Args:
23
+ log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
24
+ """
25
+ logging.basicConfig(
26
+ level=getattr(logging, log_level.upper()),
27
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
28
+ handlers=[
29
+ logging.StreamHandler(sys.stdout),
30
+ logging.FileHandler('training.log')
31
+ ]
32
+ )
33
+
34
+
35
+ def main() -> None:
36
+ """Main entry point for the training script."""
37
+ try:
38
+ # Parse command-line arguments
39
+ args = parse_arguments()
40
+
41
+ # Set up logging
42
+ setup_logging(args.python_log_level)
43
+ logger = logging.getLogger(__name__)
44
+
45
+ logger.info("Starting OFDM channel estimation model training")
46
+ logger.info(f"Model: {args.model_name}")
47
+ logger.info(f"System config: {args.system_config_path}")
48
+ logger.info(f"Model config: {args.model_config_path}")
49
+ logger.info(f"Experiment ID: {args.exp_id}")
50
+
51
+ # Load and validate configurations
52
+ logger.info("Loading configuration files...")
53
+ system_config, model_config = load_config(
54
+ args.system_config_path,
55
+ args.model_config_path
56
+ )
57
+
58
+ logger.info("Configuration loaded successfully")
59
+ logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
60
+ logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
61
+ logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
62
+
63
+ # Start training
64
+ logger.info("Initializing training...")
65
+ train(system_config, model_config, args)
66
+
67
+ logger.info("Training completed successfully")
68
+
69
+ except KeyboardInterrupt:
70
+ logger.info("Training interrupted by user")
71
+ sys.exit(1)
72
+ except Exception as e:
73
+ logger.error(f"Training failed with error: {str(e)}")
74
+ sys.exit(1)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
src/main/parser.py CHANGED
@@ -10,6 +10,14 @@ of training runs.
10
  from dataclasses import dataclass
11
  from pathlib import Path
12
  import argparse
 
 
 
 
 
 
 
 
13
 
14
 
15
  @dataclass
@@ -23,6 +31,7 @@ class TrainingArguments:
23
  # Model Configuration
24
  model_name: Supports Linear, AdaFortiTran, or FortiTran training
25
  system_config_path: Path to OFDM system configuration file
 
26
 
27
  # Dataset Paths
28
  train_set: Path to training dataset directory
@@ -39,6 +48,8 @@ class TrainingArguments:
39
  lr: Learning rate for optimizer
40
  max_epoch: Maximum number of training epochs
41
  patience: Early stopping patience in epochs
 
 
42
 
43
  # Hardware & Evaluation
44
  cuda: CUDA device index
@@ -48,6 +59,7 @@ class TrainingArguments:
48
  # Model Configuration
49
  model_name: str
50
  system_config_path: Path
 
51
 
52
  # Dataset Paths
53
  train_set: Path
@@ -64,6 +76,8 @@ class TrainingArguments:
64
  lr: float = 1e-3
65
  max_epoch: int = 10
66
  patience: int = 3
 
 
67
 
68
  # Hardware & Evaluation
69
  cuda: int = 0
@@ -84,16 +98,22 @@ class TrainingArguments:
84
  def _validate_paths(self) -> None:
85
  """Validate path-related arguments.
86
 
87
- Checks that the config file exists and has the correct extension.
88
 
89
  Raises:
90
- ValueError: If the config file doesn't exist or isn't a YAML file
91
  """
92
  if not self.system_config_path.exists():
93
- raise ValueError(f"Config file not found: {self.system_config_path}")
94
 
95
  if not self.system_config_path.suffix == '.yaml':
96
- raise ValueError(f"Config file must be a .yaml file: {self.system_config_path}")
 
 
 
 
 
 
97
 
98
  def _validate_numeric_args(self) -> None:
99
  """Validate numeric arguments.
@@ -159,6 +179,12 @@ def parse_arguments() -> TrainingArguments:
159
  required=True,
160
  help='Path to YAML file containing OFDM system parameters'
161
  )
 
 
 
 
 
 
162
  required.add_argument(
163
  '--train_set',
164
  type=Path,
@@ -234,8 +260,25 @@ def parse_arguments() -> TrainingArguments:
234
  default=1e-3,
235
  help='Initial learning rate'
236
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  args = parser.parse_args()
239
 
 
 
 
240
  # Create and validate TrainingArguments
241
  return TrainingArguments(**vars(args))
 
10
  from dataclasses import dataclass
11
  from pathlib import Path
12
  import argparse
13
+ from enum import Enum
14
+
15
+
16
+ class LossType(Enum):
17
+ """Enumeration of supported loss functions."""
18
+ MSE = "mse"
19
+ MAE = "mae"
20
+ HUBER = "huber"
21
 
22
 
23
  @dataclass
 
31
  # Model Configuration
32
  model_name: Supports Linear, AdaFortiTran, or FortiTran training
33
  system_config_path: Path to OFDM system configuration file
34
+ model_config_path: Path to model configuration file
35
 
36
  # Dataset Paths
37
  train_set: Path to training dataset directory
 
48
  lr: Learning rate for optimizer
49
  max_epoch: Maximum number of training epochs
50
  patience: Early stopping patience in epochs
51
+ loss_type: Type of loss function to use
52
+ return_type: Type of data to return from dataset
53
 
54
  # Hardware & Evaluation
55
  cuda: CUDA device index
 
59
  # Model Configuration
60
  model_name: str
61
  system_config_path: Path
62
+ model_config_path: Path
63
 
64
  # Dataset Paths
65
  train_set: Path
 
76
  lr: float = 1e-3
77
  max_epoch: int = 10
78
  patience: int = 3
79
+ loss_type: LossType = LossType.MSE
80
+ return_type: str = "complex"
81
 
82
  # Hardware & Evaluation
83
  cuda: int = 0
 
98
  def _validate_paths(self) -> None:
99
  """Validate path-related arguments.
100
 
101
+ Checks that the config files exist and have the correct extension.
102
 
103
  Raises:
104
+ ValueError: If the config files don't exist or aren't YAML files
105
  """
106
  if not self.system_config_path.exists():
107
+ raise ValueError(f"System config file not found: {self.system_config_path}")
108
 
109
  if not self.system_config_path.suffix == '.yaml':
110
+ raise ValueError(f"System config file must be a .yaml file: {self.system_config_path}")
111
+
112
+ if not self.model_config_path.exists():
113
+ raise ValueError(f"Model config file not found: {self.model_config_path}")
114
+
115
+ if not self.model_config_path.suffix == '.yaml':
116
+ raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
117
 
118
  def _validate_numeric_args(self) -> None:
119
  """Validate numeric arguments.
 
179
  required=True,
180
  help='Path to YAML file containing OFDM system parameters'
181
  )
182
+ required.add_argument(
183
+ '--model_config_path',
184
+ type=Path,
185
+ required=True,
186
+ help='Path to YAML file containing model architecture parameters'
187
+ )
188
  required.add_argument(
189
  '--train_set',
190
  type=Path,
 
260
  default=1e-3,
261
  help='Initial learning rate'
262
  )
263
+ optional.add_argument(
264
+ '--loss_type',
265
+ type=str,
266
+ default="mse",
267
+ choices=['mse', 'mae', 'huber'],
268
+ help='Loss function type'
269
+ )
270
+ optional.add_argument(
271
+ '--return_type',
272
+ type=str,
273
+ default="complex",
274
+ choices=['complex', 'real'],
275
+ help='Type of data to return from dataset'
276
+ )
277
 
278
  args = parser.parse_args()
279
 
280
+ # Convert loss_type string to enum
281
+ args.loss_type = LossType(args.loss_type)
282
+
283
  # Create and validate TrainingArguments
284
  return TrainingArguments(**vars(args))
src/main/trainer.py CHANGED
@@ -10,8 +10,10 @@ training loop management, evaluation, and result logging.
10
  import torch
11
  from torch import nn, optim
12
  from torch.utils.data import DataLoader
13
- from torch.utils.tensorboard import SummaryWriter
14
  from typing import Dict, Tuple, Type, Union
 
 
15
 
16
  from .parser import TrainingArguments
17
  from src.data.dataset import MatDataset, get_test_dataloaders
@@ -21,14 +23,11 @@ from src.utils import (
21
  get_ls_mse_per_folder,
22
  get_model_details,
23
  get_test_stats_plot,
24
- get_error_images
25
- )
26
- from src.main.train_helpers import (
27
- get_all_test_stats,
28
- train_epoch,
29
- eval_model,
30
- predict_channels
31
  )
 
32
 
33
  # A union type representing supported model classes
34
  ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
@@ -48,16 +47,18 @@ class ModelTrainer:
48
  Attributes:
49
  MODEL_REGISTRY: Dictionary mapping model names to model classes
50
  system_config: OFDM system configuration
51
- args: Training arguments
 
52
  device: PyTorch device for computation
53
  writer: TensorBoard SummaryWriter for logging
54
- model: Initialized model instance
55
  optimizer: Torch optimizer for training
56
- scheduler: Learning rate scheduler
57
  early_stopper: Helper for early stopping
58
- train_loader: DataLoader for training set
59
- val_loader: DataLoader for validation set
60
- test_loaders: Dictionary of test set DataLoaders
 
61
  """
62
 
63
  MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
@@ -66,47 +67,33 @@ class ModelTrainer:
66
  "fortitran": FortiTranEstimator,
67
  }
68
 
69
- def __init__(self, system_config: Dict, args: TrainingArguments):
 
 
70
  """
71
  Initialize the ModelTrainer.
72
 
73
  Args:
74
- config: Model configuration dictionary from YAML file
75
- args: Validated training arguments
 
76
  """
77
  self.system_config = system_config
 
78
  self.args = args
79
  self.device = torch.device(f"cuda:{args.cuda}")
80
  self.writer = self._setup_tensorboard()
 
81
 
82
  self.model = self._initialize_model()
83
  self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
84
- self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.995)
85
  self.early_stopper = EarlyStopping(patience=args.patience)
86
 
87
- self.training_loss = self._get_loss_function()
88
- self.comparison_loss = nn.MSELoss() # used for test set evaluation
89
 
90
  self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
91
 
92
- def _get_loss_function(self) -> nn.Module:
93
- """Get the appropriate loss function based on arguments.
94
-
95
- Returns:
96
- The selected PyTorch loss function based on args.loss_type
97
-
98
- Raises:
99
- ValueError: If an unsupported loss type is specified
100
- """
101
- if self.args.loss_type == LossType.MSE:
102
- return nn.MSELoss()
103
- elif self.args.loss_type == LossType.MAE:
104
- return nn.L1Loss()
105
- elif self.args.loss_type == LossType.HUBER:
106
- return nn.HuberLoss()
107
- else:
108
- raise ValueError(f"Unsupported loss type: {self.args.loss_type}")
109
-
110
  def _setup_tensorboard(self) -> SummaryWriter:
111
  """Set up TensorBoard logging.
112
 
@@ -134,38 +121,30 @@ class ModelTrainer:
134
  Initialized model instance of the specified type
135
  """
136
  model_class = self.MODEL_REGISTRY[self.args.model_name]
137
- model = model_class(self.device, self.config, vars(self.args))
138
-
 
 
 
 
139
  num_params, model_summary = get_model_details(model)
140
- print(model_summary)
141
- print(f"Model name: {self.config['model_name']}\nNumber of parameters: {num_params}")
 
142
  self.writer.add_text("Number of Parameters", str(num_params))
143
-
144
  return model
145
 
146
  def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
147
- """Initialize all required dataloaders.
148
-
149
- Creates DataLoader instances for:
150
- - Training dataset
151
- - Validation dataset
152
- - Test datasets grouped by test condition (DS, MDS, SNR)
153
-
154
- Returns:
155
- Tuple containing (train_loader, val_loader, test_loaders_dict)
156
- """
157
  # Training and validation dataloaders
158
  train_dataset = MatDataset(
159
  self.args.train_set,
160
- self.args.pilot_dims,
161
- return_type=self.config["return_type"]
162
  )
163
  val_dataset = MatDataset(
164
  self.args.val_set,
165
- self.args.pilot_dims,
166
- return_type=self.config["return_type"]
167
  )
168
-
169
  train_loader = DataLoader(
170
  train_dataset,
171
  batch_size=self.args.batch_size,
@@ -176,43 +155,34 @@ class ModelTrainer:
176
  batch_size=self.args.batch_size,
177
  shuffle=True
178
  )
179
-
180
- # Test dataloaders
181
  test_loaders = {
182
  "DS": get_test_dataloaders(
183
  self.args.test_set / "DS_test_set",
184
- vars(self.args),
185
- self.config["return_type"]
186
  ),
187
  "MDS": get_test_dataloaders(
188
  self.args.test_set / "MDS_test_set",
189
- vars(self.args),
190
- self.config["return_type"]
191
  ),
192
  "SNR": get_test_dataloaders(
193
  self.args.test_set / "SNR_test_set",
194
- vars(self.args),
195
- self.config["return_type"]
196
  ),
197
  }
198
-
199
  return train_loader, val_loader, test_loaders
200
 
201
  def _log_test_results(
202
  self,
203
  epoch: int,
204
- test_stats: Dict[str, Dict],
205
- ls_stats: Dict[str, Dict]
206
  ) -> None:
207
  """Log test results to TensorBoard.
208
 
209
- Creates and logs visualizations comparing model performance against
210
- baseline LS estimator across different test conditions.
211
 
212
  Args:
213
  epoch: Current training epoch
214
  test_stats: Dictionary of test statistics for the model
215
- ls_stats: Dictionary of test statistics for the LS baseline
216
  """
217
  for key in ("DS", "MDS", "SNR"):
218
  # Plot test statistics
@@ -220,16 +190,13 @@ class ModelTrainer:
220
  tag=f"MSE vs. {key} (Epoch:{epoch + 1})",
221
  figure=get_test_stats_plot(
222
  x_name=key,
223
- stats=[test_stats[key], ls_stats[key]],
224
- methods=[self.config["model_name"], "LS"]
225
  )
226
  )
227
 
228
  # Plot error images
229
- predicted_channels = predict_channels(
230
- self.model,
231
- self.test_loaders[key]
232
- )
233
  self.writer.add_figure(
234
  tag=f"{key} Error Images (Epoch:{epoch + 1})",
235
  figure=get_error_images(
@@ -242,23 +209,12 @@ class ModelTrainer:
242
  def _run_tests(self, epoch: int) -> None:
243
  """Run tests and log results.
244
 
245
- Evaluates the model on all test datasets, compares with LS baseline,
246
- and logs performance metrics and visualizations.
247
 
248
  Args:
249
  epoch: Current training epoch
250
  """
251
- ds_stats, mds_stats, snr_stats = get_all_test_stats(
252
- self.model,
253
- self.test_loaders,
254
- self.comparison_loss
255
- )
256
-
257
- ls_stats = {
258
- "DS": get_ls_mse_per_folder(self.args.test_set / "DS_test_set"),
259
- "MDS": get_ls_mse_per_folder(self.args.test_set / "MDS_test_set"),
260
- "SNR": get_ls_mse_per_folder(self.args.test_set / "SNR_test_set")
261
- }
262
 
263
  test_stats = {
264
  "DS": ds_stats,
@@ -266,7 +222,7 @@ class ModelTrainer:
266
  "SNR": snr_stats
267
  }
268
 
269
- self._log_test_results(epoch, test_stats, ls_stats)
270
 
271
  def _log_final_metrics(self, final_epoch: int) -> None:
272
  """Log final training metrics and hyperparameters.
@@ -286,11 +242,7 @@ class ModelTrainer:
286
 
287
  try:
288
  for key in ("DS", "MDS", "SNR"):
289
- ds_stats, mds_stats, snr_stats = get_all_test_stats(
290
- self.model,
291
- self.test_loaders,
292
- self.comparison_loss
293
- )
294
  ls_stats = {
295
  "DS": get_ls_mse_per_folder(self.args.test_set / "DS_test_set"),
296
  "MDS": get_ls_mse_per_folder(self.args.test_set / "MDS_test_set"),
@@ -309,13 +261,95 @@ class ModelTrainer:
309
  key,
310
  {
311
  "LS": ls_stats[key][val],
312
- self.config["model_name"]: stats[val]
313
  },
314
  val
315
  )
316
  except Exception as e:
317
  self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  def train(self) -> None:
320
  """Execute the training loop.
321
 
@@ -325,62 +359,35 @@ class ModelTrainer:
325
  - Early stopping when validation loss plateaus
326
  - Logging final metrics and results
327
  """
328
- try:
329
- from tqdm import tqdm
330
- use_tqdm = True
331
- except ImportError:
332
- use_tqdm = False
333
- print("tqdm not found, progress bar will not be displayed")
334
-
335
  epoch = None
336
-
337
- # Create progress bar if tqdm is available
338
- if use_tqdm:
339
- pbar = tqdm(range(self.args.max_epoch), desc="Training")
340
- else:
341
- pbar = range(self.args.max_epoch)
342
-
343
  for epoch in pbar:
344
  # Training step
345
- train_loss = train_epoch(
346
- self.model,
347
- self.optimizer,
348
- self.training_loss,
349
- self.scheduler,
350
- self.train_loader
351
- )
352
  self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
353
 
354
  # Validation step
355
- val_loss = eval_model(self.model, self.val_loader, self.training_loss)
356
  self.writer.add_scalar('Loss/Val', val_loss, epoch + 1)
357
 
358
- # Update progress bar with loss info if tqdm is available
359
- if use_tqdm:
360
- pbar.set_description(
361
- f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
362
 
363
  if self.early_stopper.early_stop(val_loss):
364
- if use_tqdm:
365
- pbar.write(f"Early stopping triggered at epoch {epoch + 1}")
366
- else:
367
- print(f"Early stopping triggered at epoch {epoch + 1}")
368
  break
369
 
370
  # Periodic testing
371
  if (epoch + 1) % self.args.test_every_n == 0:
372
  message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
373
- if use_tqdm:
374
- pbar.write(message)
375
- else:
376
- print(message)
377
  self._run_tests(epoch)
378
-
379
  self._log_final_metrics(epoch)
380
  self.writer.close()
381
 
382
 
383
- def train(config: Dict, args: TrainingArguments) -> None:
384
  """
385
  Train an OFDM channel estimation model.
386
 
@@ -388,11 +395,11 @@ def train(config: Dict, args: TrainingArguments) -> None:
388
  with the specified configuration and runs the training process.
389
 
390
  Args:
391
- config: Model configuration dictionary loaded from YAML file,
392
- containing model architecture and training parameters
393
  args: Validated training arguments containing all necessary parameters
394
  for model training, including dataset paths, hyperparameters,
395
  and logging configuration
396
  """
397
- trainer = ModelTrainer(config, args)
398
  trainer.train()
 
10
  import torch
11
  from torch import nn, optim
12
  from torch.utils.data import DataLoader
13
+ from torch.utils.tensorboard.writer import SummaryWriter
14
  from typing import Dict, Tuple, Type, Union
15
+ import logging
16
+ from tqdm import tqdm
17
 
18
  from .parser import TrainingArguments
19
  from src.data.dataset import MatDataset, get_test_dataloaders
 
23
  get_ls_mse_per_folder,
24
  get_model_details,
25
  get_test_stats_plot,
26
+ get_error_images,
27
+ concat_complex_channel,
28
+ to_db
 
 
 
 
29
  )
30
+ from src.config.schemas import SystemConfig, ModelConfig
31
 
32
  # A union type representing supported model classes
33
  ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
 
47
  Attributes:
48
  MODEL_REGISTRY: Dictionary mapping model names to model classes
49
  system_config: OFDM system configuration
50
+ model_config: OFDM model configuration
51
+ args: Training arguments parsed from command line
52
  device: PyTorch device for computation
53
  writer: TensorBoard SummaryWriter for logging
54
+ model: Initialized Torch model instance
55
  optimizer: Torch optimizer for training
56
+ scheduler: Learning rate scheduler for training
57
  early_stopper: Helper for early stopping
58
+ train_loader: DataLoader for training set (used for training)
59
+ val_loader: DataLoader for validation set (used for validation)
60
+ test_loaders: Dictionary of test set DataLoaders (used for testing)
61
+ logger: Logger instance for logging messages
62
  """
63
 
64
  MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
 
67
  "fortitran": FortiTranEstimator,
68
  }
69
 
70
+ EXP_LR_GAMMA = 0.995
71
+
72
+ def __init__(self, system_config: SystemConfig, model_config: ModelConfig | None, args: TrainingArguments):
73
  """
74
  Initialize the ModelTrainer.
75
 
76
  Args:
77
+ system_config: OFDM system configuration dictionary from YAML file
78
+ model_config: OFDM model configuration dictionary from YAML file
79
+ args: Validated training arguments parsed from command line
80
  """
81
  self.system_config = system_config
82
+ self.model_config = model_config
83
  self.args = args
84
  self.device = torch.device(f"cuda:{args.cuda}")
85
  self.writer = self._setup_tensorboard()
86
+ self.logger = logging.getLogger(__name__)
87
 
88
  self.model = self._initialize_model()
89
  self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
90
+ self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
91
  self.early_stopper = EarlyStopping(patience=args.patience)
92
 
93
+ self.training_loss = nn.MSELoss()
 
94
 
95
  self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def _setup_tensorboard(self) -> SummaryWriter:
98
  """Set up TensorBoard logging.
99
 
 
121
  Initialized model instance of the specified type
122
  """
123
  model_class = self.MODEL_REGISTRY[self.args.model_name]
124
+ if model_class is LinearEstimator:
125
+ model = model_class(self.system_config, device=str(self.device))
126
+ else:
127
+ if self.model_config is None:
128
+ raise ValueError("model_config must be provided for non-linear models.")
129
+ model = model_class(self.system_config, self.model_config)
130
  num_params, model_summary = get_model_details(model)
131
+ self.logger.info("\n" + model_summary)
132
+ self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
133
+ self.writer.add_text("Model Summary", model_summary)
134
  self.writer.add_text("Number of Parameters", str(num_params))
 
135
  return model
136
 
137
  def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
138
+ pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]
 
 
 
 
 
 
 
 
 
139
  # Training and validation dataloaders
140
  train_dataset = MatDataset(
141
  self.args.train_set,
142
+ pilot_dims
 
143
  )
144
  val_dataset = MatDataset(
145
  self.args.val_set,
146
+ pilot_dims
 
147
  )
 
148
  train_loader = DataLoader(
149
  train_dataset,
150
  batch_size=self.args.batch_size,
 
155
  batch_size=self.args.batch_size,
156
  shuffle=True
157
  )
 
 
158
  test_loaders = {
159
  "DS": get_test_dataloaders(
160
  self.args.test_set / "DS_test_set",
161
+ {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
 
162
  ),
163
  "MDS": get_test_dataloaders(
164
  self.args.test_set / "MDS_test_set",
165
+ {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
 
166
  ),
167
  "SNR": get_test_dataloaders(
168
  self.args.test_set / "SNR_test_set",
169
+ {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
 
170
  ),
171
  }
 
172
  return train_loader, val_loader, test_loaders
173
 
174
  def _log_test_results(
175
  self,
176
  epoch: int,
177
+ test_stats: Dict[str, Dict]
 
178
  ) -> None:
179
  """Log test results to TensorBoard.
180
 
181
+ Creates and logs visualizations for model performance across different test conditions.
 
182
 
183
  Args:
184
  epoch: Current training epoch
185
  test_stats: Dictionary of test statistics for the model
 
186
  """
187
  for key in ("DS", "MDS", "SNR"):
188
  # Plot test statistics
 
190
  tag=f"MSE vs. {key} (Epoch:{epoch + 1})",
191
  figure=get_test_stats_plot(
192
  x_name=key,
193
+ stats=[test_stats[key]],
194
+ methods=[self.args.model_name]
195
  )
196
  )
197
 
198
  # Plot error images
199
+ predicted_channels = self._predict_channels(self.test_loaders[key])
 
 
 
200
  self.writer.add_figure(
201
  tag=f"{key} Error Images (Epoch:{epoch + 1})",
202
  figure=get_error_images(
 
209
  def _run_tests(self, epoch: int) -> None:
210
  """Run tests and log results.
211
 
212
+ Evaluates the model on all test datasets and logs performance metrics and visualizations.
 
213
 
214
  Args:
215
  epoch: Current training epoch
216
  """
217
+ ds_stats, mds_stats, snr_stats = self._get_all_test_stats()
 
 
 
 
 
 
 
 
 
 
218
 
219
  test_stats = {
220
  "DS": ds_stats,
 
222
  "SNR": snr_stats
223
  }
224
 
225
+ self._log_test_results(epoch, test_stats)
226
 
227
  def _log_final_metrics(self, final_epoch: int) -> None:
228
  """Log final training metrics and hyperparameters.
 
242
 
243
  try:
244
  for key in ("DS", "MDS", "SNR"):
245
+ ds_stats, mds_stats, snr_stats = self._get_all_test_stats()
 
 
 
 
246
  ls_stats = {
247
  "DS": get_ls_mse_per_folder(self.args.test_set / "DS_test_set"),
248
  "MDS": get_ls_mse_per_folder(self.args.test_set / "MDS_test_set"),
 
261
  key,
262
  {
263
  "LS": ls_stats[key][val],
264
+ self.args.model_name: stats[val]
265
  },
266
  val
267
  )
268
  except Exception as e:
269
  self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
270
 
271
+ def _compute_loss(self, estimated_channel, ideal_channel, loss_fn):
272
+ return loss_fn(
273
+ concat_complex_channel(estimated_channel),
274
+ concat_complex_channel(ideal_channel)
275
+ )
276
+
277
+ def _forward_pass(self, batch, model):
278
+ estimated_channel, ideal_channel, meta_data = batch
279
+ if hasattr(model, 'name') and model.name in ["fortitran", "MMSE"]:
280
+ h_est_re = model(torch.real(estimated_channel))
281
+ h_est_im = model(torch.imag(estimated_channel))
282
+ estimated_channel = torch.complex(h_est_re, h_est_im)
283
+ elif hasattr(model, 'name') and model.name == "adafortitran":
284
+ h_est_re = model(torch.real(estimated_channel), meta_data)
285
+ h_est_im = model(torch.imag(estimated_channel), meta_data)
286
+ estimated_channel = torch.complex(h_est_re, h_est_im)
287
+ else:
288
+ raise ValueError(f"Unknown model type: {getattr(model, 'name', type(model))}")
289
+ return estimated_channel, ideal_channel.to(model.device)
290
+
291
+ def _train_epoch(self):
292
+ train_loss = 0.0
293
+ self.model.train()
294
+ for batch in self.train_loader:
295
+ self.optimizer.zero_grad()
296
+ estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
297
+ output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
298
+ output.backward()
299
+ self.optimizer.step()
300
+ train_loss += (2 * output.item() * batch[0].size(0))
301
+ self.scheduler.step()
302
+ train_loss /= len(self.train_loader.dataset)
303
+ return train_loss
304
+
305
+ def _eval_model(self, eval_dataloader):
306
+ val_loss = 0.0
307
+ self.model.eval()
308
+ with torch.no_grad():
309
+ for batch in eval_dataloader:
310
+ estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
311
+ output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
312
+ val_loss += (2 * output.item() * batch[0].size(0))
313
+ val_loss /= len(eval_dataloader.dataset)
314
+ return val_loss
315
+
316
+ def _predict_channels(self, test_dataloaders):
317
+ channels = {}
318
+ sorted_loaders = sorted(
319
+ test_dataloaders,
320
+ key=lambda x: int(x[0].split("_")[1])
321
+ )
322
+ for name, test_dataloader in sorted_loaders:
323
+ with torch.no_grad():
324
+ batch = next(iter(test_dataloader))
325
+ estimated_channels, ideal_channels = self._forward_pass(batch, self.model)
326
+ var, val = name.split("_")
327
+ channels[int(val)] = {
328
+ "estimated_channel": estimated_channels[0],
329
+ "ideal_channel": ideal_channels[0]
330
+ }
331
+ return channels
332
+
333
+ def _get_test_stats(self, test_dataloaders):
334
+ stats = {}
335
+ sorted_loaders = sorted(
336
+ test_dataloaders,
337
+ key=lambda x: int(x[0].split("_")[1])
338
+ )
339
+ for name, test_dataloader in sorted_loaders:
340
+ var, val = name.split("_")
341
+ test_loss = self._eval_model(test_dataloader)
342
+ db_error = to_db(test_loss)
343
+ self.logger.info(f"{var}:{val} Test MSE: {db_error:.4f} dB")
344
+ stats[int(val)] = db_error
345
+ return stats
346
+
347
+ def _get_all_test_stats(self):
348
+ ds_stats = self._get_test_stats(self.test_loaders["DS"])
349
+ mds_stats = self._get_test_stats(self.test_loaders["MDS"])
350
+ snr_stats = self._get_test_stats(self.test_loaders["SNR"])
351
+ return ds_stats, mds_stats, snr_stats
352
+
353
  def train(self) -> None:
354
  """Execute the training loop.
355
 
 
359
  - Early stopping when validation loss plateaus
360
  - Logging final metrics and results
361
  """
 
 
 
 
 
 
 
362
  epoch = None
363
+ pbar = tqdm(range(self.args.max_epoch), desc="Training")
 
 
 
 
 
 
364
  for epoch in pbar:
365
  # Training step
366
+ train_loss = self._train_epoch()
 
 
 
 
 
 
367
  self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
368
 
369
  # Validation step
370
+ val_loss = self._eval_model(self.val_loader)
371
  self.writer.add_scalar('Loss/Val', val_loss, epoch + 1)
372
 
373
+ # Update progress bar with loss info
374
+ pbar.set_description(
375
+ f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
 
376
 
377
  if self.early_stopper.early_stop(val_loss):
378
+ pbar.write(f"Early stopping triggered at epoch {epoch + 1}")
 
 
 
379
  break
380
 
381
  # Periodic testing
382
  if (epoch + 1) % self.args.test_every_n == 0:
383
  message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
384
+ pbar.write(message)
 
 
 
385
  self._run_tests(epoch)
 
386
  self._log_final_metrics(epoch)
387
  self.writer.close()
388
 
389
 
390
def train(system_config: SystemConfig, model_config: ModelConfig | None, args: TrainingArguments) -> None:
    """
    Train an OFDM channel estimation model.

    Convenience entry point: builds a ModelTrainer from the given
    configurations and immediately runs its training loop.

    Args:
        system_config: OFDM system configuration parsed from YAML
        model_config: OFDM model configuration parsed from YAML (may be
            None for models that only need the system configuration)
        args: Validated training arguments containing all necessary
            parameters for model training, including dataset paths,
            hyperparameters, and logging configuration
    """
    ModelTrainer(system_config, model_config, args).train()
  trainer.train()
src/models/adafortitran.py CHANGED
@@ -1,5 +1,5 @@
1
  from .fortitran import BaseFortiTranEstimator
2
- from src.config import SystemConfig, ModelConfig
3
 
4
 
5
  class AdaFortiTranEstimator(BaseFortiTranEstimator):
 
1
  from .fortitran import BaseFortiTranEstimator
2
+ from src.config.schemas import SystemConfig, ModelConfig
3
 
4
 
5
  class AdaFortiTranEstimator(BaseFortiTranEstimator):
src/models/fortitran.py CHANGED
@@ -3,7 +3,7 @@ from torch import nn
3
  import logging
4
  from typing import Tuple, List, Optional
5
 
6
- from src.config import SystemConfig, ModelConfig
7
  from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, \
8
  ChannelAdapter
9
 
 
3
  import logging
4
  from typing import Tuple, List, Optional
5
 
6
+ from src.config.schemas import SystemConfig, ModelConfig
7
  from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, \
8
  ChannelAdapter
9
 
src/models/linear.py CHANGED
@@ -27,16 +27,17 @@ class LinearEstimator(nn.Module):
27
  width (int): number of pilots across OFDM symbols
28
  """
29
 
30
- def __init__(self, config: SystemConfig) -> None:
31
  """Initialize the MMSE estimator.
32
 
33
  Args:
34
  config: Validated SystemConfig object containing OFDM system parameters
 
35
  """
36
  super().__init__()
37
 
38
  self.config = config
39
- self.device = torch.device(config.device)
40
  self.logger = logging.getLogger(__name__)
41
 
42
  # Extract dimensions from validated config
 
27
  width (int): number of pilots across OFDM symbols
28
  """
29
 
30
+ def __init__(self, config: SystemConfig, device: str = "cpu") -> None:
31
  """Initialize the MMSE estimator.
32
 
33
  Args:
34
  config: Validated SystemConfig object containing OFDM system parameters
35
+ device: Device to use for computation (cpu, cuda, etc.)
36
  """
37
  super().__init__()
38
 
39
  self.config = config
40
+ self.device = torch.device(device)
41
  self.logger = logging.getLogger(__name__)
42
 
43
  # Extract dimensions from validated config