Commit 4e938bd
1 Parent(s): 54d5c08

refactored trainer class

Files changed:
- .gitignore +1 -1
- requirements.txt +2 -1
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__pycache__/schemas.cpython-312.pyc +0 -0
- src/config/config_loader.py +10 -13
- src/config/schemas.py +81 -30
- src/main.py +78 -0
- src/main/parser.py +47 -4
- src/main/trainer.py +144 -137
- src/models/adafortitran.py +1 -1
- src/models/fortitran.py +1 -1
- src/models/linear.py +3 -2
.gitignore CHANGED

@@ -1,2 +1,2 @@
 .idea/
-
+**/__pycache__/
requirements.txt CHANGED

@@ -1,4 +1,5 @@
 torch
 pydantic
 yaml
-scipy
+scipy
+tqdm
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes).

src/config/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (246 Bytes).

src/config/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (9.4 kB).
src/config/config_loader.py CHANGED

@@ -21,10 +21,8 @@ class ConfigLoader:
            system_config_path: Path to YAML configuration file for OFDM-related parameters
            model_config_path: Path to YAML configuration file for model-related parameters
 
-
        Returns:
-            ModelConfig: Validated
-            SystemConfig: Validated system configuration object
+            Tuple of (SystemConfig, ModelConfig): Validated configuration objects
 
        Raises:
            FileNotFoundError: If one of the config files doesn't exist
@@ -34,10 +32,10 @@ class ConfigLoader:
        model_config_path = Path(model_config_path)
 
        if not system_config_path.exists():
-            raise FileNotFoundError(f"
+            raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
 
        if not model_config_path.exists():
-            raise FileNotFoundError(f"
+            raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
 
        try:
            with open(system_config_path, 'r') as f:
@@ -55,16 +53,15 @@ class ConfigLoader:
            system_config = SystemConfig(**system_raw_config)
            self.logger.info(f"Successfully loaded system config from {system_config_path}")
        except ValidationError as e:
-            raise ValueError(f"
-        if system_config:
-            try:
-                model_config = ModelConfig(system_config, **model_raw_config)
-                self.logger.info(f"Successfully loaded model config from {model_config_path}")
-            except ValidationError as e:
-                raise ValueError(f"Configuration validation for {model_config_path} failed:\n{e}")
+            raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
 
-
+        try:
+            model_config = ModelConfig(**model_raw_config)
+            self.logger.info(f"Successfully loaded model config from {model_config_path}")
+        except ValidationError as e:
+            raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
 
+        return system_config, model_config
 
 
 def load_config(system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
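The loader now validates both files and returns the pair, so call sites can unpack it in one step. A minimal usage sketch (the YAML paths are hypothetical, not part of this commit):

    from src.config.config_loader import load_config

    # Any two YAML files matching the SystemConfig/ModelConfig schemas will do.
    system_config, model_config = load_config(
        "configs/system.yaml",  # hypothetical path
        "configs/model.yaml",   # hypothetical path
    )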
src/config/schemas.py CHANGED

@@ -1,5 +1,5 @@
 from pydantic import BaseModel, Field, model_validator
-from typing import Self, Tuple
+from typing import Self, Tuple, List, Optional
 import torch
 
 
@@ -14,8 +14,19 @@ class PilotParams(BaseModel):
 
 
 class ModelParams(BaseModel):
-    patch_size: Tuple[int, int] = Field(
-    num_layers: int = Field(
+    patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
+    num_layers: int = Field(..., gt=0, description="Number of transformer layers")
+    model_dim: int = Field(..., gt=0, description="Model dimension")
+    num_head: int = Field(..., gt=0, description="Number of attention heads")
+    activation: str = Field(default="gelu", description="Activation function")
+    dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
+    max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
+    pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
+    adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
+    channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
+        default=None,
+        description="Hidden sizes for channel adaptation layers"
+    )
     device: str = Field(default="cpu", description="Device to use")
 
     @model_validator(mode='after')
@@ -103,39 +114,79 @@ class SystemConfig(BaseModel):
 
 
 class ModelConfig(BaseModel):
-
-
+    patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
+    num_layers: int = Field(..., gt=0, description="Number of transformer layers")
+    model_dim: int = Field(..., gt=0, description="Model dimension")
+    num_head: int = Field(..., gt=0, description="Number of attention heads")
+    activation: str = Field(default="gelu", description="Activation function")
+    dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
+    max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
+    pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
+    adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
+    channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
+        default=None,
+        description="Hidden sizes for channel adaptation layers"
+    )
+    device: str = Field(default="cpu", description="Device to use")
 
     @model_validator(mode='after')
-    def
-        """
-
-        )
-            f"OFDM sub-carriers ({self.system.ofdm.num_scs})"
-        )
-        #
-        if
-            f"by patch height ({patch_height}) for clean patching"
-        )
-
+    def validate_device(self) -> Self:
+        """Validate that the specified device is available."""
+        device_str = self.device.lower()
+
+        # Handle 'auto' case - automatically select best available device
+        if device_str == 'auto':
+            if torch.cuda.is_available():
+                self.device = 'cuda'
+            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                self.device = 'mps'  # Apple Silicon
+            else:
+                self.device = 'cpu'
+            return self
+
+        # Validate CPU
+        if device_str == 'cpu':
+            return self
+
+        # Validate CUDA devices
+        if device_str.startswith('cuda'):
+            if not torch.cuda.is_available():
+                raise ValueError("CUDA is not available on this system")
+
+            # Handle specific CUDA device (e.g., 'cuda:0', 'cuda:1')
+            if ':' in device_str:
+                try:
+                    device_id = int(device_str.split(':')[1])
+                    if device_id >= torch.cuda.device_count():
+                        available_devices = list(range(torch.cuda.device_count()))
+                        raise ValueError(
+                            f"CUDA device {device_id} not available. "
+                            f"Available CUDA devices: {available_devices}"
+                        )
+                except (ValueError, IndexError) as e:
+                    if "invalid literal" in str(e):
+                        raise ValueError(f"Invalid CUDA device format: {device_str}")
+                    raise
+
+            return self
+
+        # Validate MPS (Apple Silicon)
+        if device_str == 'mps':
+            if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
+                raise ValueError("MPS is not available on this system")
+            return self
+
+        # If we get here, the device is not recognized
+        available_devices = ['cpu']
+        if torch.cuda.is_available():
+            cuda_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
+            available_devices.extend(['cuda'] + cuda_devices)
+        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+            available_devices.append('mps')
+
+        raise ValueError(
+            f"Unsupported device: '{self.device}'. "
+            f"Available devices: {available_devices}"
+        )
 
     model_config = {"extra": "forbid"}
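The new validate_device validator also resolves 'auto' to the best available backend (CUDA, then MPS, then CPU) before the config is used anywhere else. A minimal sketch, with hypothetical values chosen only to satisfy the required fields:

    from src.config.schemas import ModelConfig

    cfg = ModelConfig(
        patch_size=(4, 4),  # hypothetical
        num_layers=2,       # hypothetical
        model_dim=64,       # hypothetical
        num_head=4,         # hypothetical
        device="auto",
    )
    print(cfg.device)  # resolves to 'cuda', 'mps', or 'cpu' depending on the host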
src/main.py ADDED

@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""
+Main entry point for OFDM channel estimation model training.
+
+This script provides the command-line interface for training OFDM channel estimation
+models. It loads configuration files, parses command-line arguments, and initiates
+the training process.
+"""
+
+import logging
+import sys
+from pathlib import Path
+
+from src.main.parser import parse_arguments
+from src.main.trainer import train
+from src.config.config_loader import load_config
+
+
+def setup_logging(log_level: str) -> None:
+    """Set up logging configuration.
+
+    Args:
+        log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+    """
+    logging.basicConfig(
+        level=getattr(logging, log_level.upper()),
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.StreamHandler(sys.stdout),
+            logging.FileHandler('training.log')
+        ]
+    )
+
+
+def main() -> None:
+    """Main entry point for the training script."""
+    try:
+        # Parse command-line arguments
+        args = parse_arguments()
+
+        # Set up logging
+        setup_logging(args.python_log_level)
+        logger = logging.getLogger(__name__)
+
+        logger.info("Starting OFDM channel estimation model training")
+        logger.info(f"Model: {args.model_name}")
+        logger.info(f"System config: {args.system_config_path}")
+        logger.info(f"Model config: {args.model_config_path}")
+        logger.info(f"Experiment ID: {args.exp_id}")
+
+        # Load and validate configurations
+        logger.info("Loading configuration files...")
+        system_config, model_config = load_config(
+            args.system_config_path,
+            args.model_config_path
+        )
+
+        logger.info("Configuration loaded successfully")
+        logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
+        logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
+        logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
+
+        # Start training
+        logger.info("Initializing training...")
+        train(system_config, model_config, args)
+
+        logger.info("Training completed successfully")
+
+    except KeyboardInterrupt:
+        logger.info("Training interrupted by user")
+        sys.exit(1)
+    except Exception as e:
+        logger.error(f"Training failed with error: {str(e)}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
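One plausible invocation of this entry point from the repository root is sketched below; it assumes the repo root is on PYTHONPATH, and the dataset flag names (--val_set, --test_set) are inferred from the TrainingArguments fields rather than shown in this commit, so treat them as assumptions:

    PYTHONPATH=. python src/main.py \
        --model_name fortitran \
        --system_config_path configs/system.yaml \
        --model_config_path configs/model.yaml \
        --train_set data/train --val_set data/val --test_set data/test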
src/main/parser.py CHANGED

@@ -10,6 +10,14 @@ of training runs.
 from dataclasses import dataclass
 from pathlib import Path
 import argparse
+from enum import Enum
+
+
+class LossType(Enum):
+    """Enumeration of supported loss functions."""
+    MSE = "mse"
+    MAE = "mae"
+    HUBER = "huber"
 
 
 @dataclass
@@ -23,6 +31,7 @@ class TrainingArguments:
        # Model Configuration
        model_name: Supports Linear, AdaFortiTran, or FortiTran training
        system_config_path: Path to OFDM system configuration file
+        model_config_path: Path to model configuration file
 
        # Dataset Paths
        train_set: Path to training dataset directory
@@ -39,6 +48,8 @@ class TrainingArguments:
        lr: Learning rate for optimizer
        max_epoch: Maximum number of training epochs
        patience: Early stopping patience in epochs
+        loss_type: Type of loss function to use
+        return_type: Type of data to return from dataset
 
        # Hardware & Evaluation
        cuda: CUDA device index
@@ -48,6 +59,7 @@ class TrainingArguments:
    # Model Configuration
    model_name: str
    system_config_path: Path
+    model_config_path: Path
 
    # Dataset Paths
    train_set: Path
@@ -64,6 +76,8 @@ class TrainingArguments:
    lr: float = 1e-3
    max_epoch: int = 10
    patience: int = 3
+    loss_type: LossType = LossType.MSE
+    return_type: str = "complex"
 
    # Hardware & Evaluation
    cuda: int = 0
@@ -84,16 +98,22 @@ class TrainingArguments:
    def _validate_paths(self) -> None:
        """Validate path-related arguments.
 
-        Checks that the config
+        Checks that the config files exist and have the correct extension.
 
        Raises:
-            ValueError: If the config
+            ValueError: If the config files don't exist or aren't YAML files
        """
        if not self.system_config_path.exists():
-            raise ValueError(f"
+            raise ValueError(f"System config file not found: {self.system_config_path}")
 
        if not self.system_config_path.suffix == '.yaml':
-            raise ValueError(f"
+            raise ValueError(f"System config file must be a .yaml file: {self.system_config_path}")
+
+        if not self.model_config_path.exists():
+            raise ValueError(f"Model config file not found: {self.model_config_path}")
+
+        if not self.model_config_path.suffix == '.yaml':
+            raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
 
    def _validate_numeric_args(self) -> None:
        """Validate numeric arguments.
@@ -159,6 +179,12 @@ def parse_arguments() -> TrainingArguments:
        required=True,
        help='Path to YAML file containing OFDM system parameters'
    )
+    required.add_argument(
+        '--model_config_path',
+        type=Path,
+        required=True,
+        help='Path to YAML file containing model architecture parameters'
+    )
    required.add_argument(
        '--train_set',
        type=Path,
@@ -234,8 +260,25 @@ def parse_arguments() -> TrainingArguments:
        default=1e-3,
        help='Initial learning rate'
    )
+    optional.add_argument(
+        '--loss_type',
+        type=str,
+        default="mse",
+        choices=['mse', 'mae', 'huber'],
+        help='Loss function type'
+    )
+    optional.add_argument(
+        '--return_type',
+        type=str,
+        default="complex",
+        choices=['complex', 'real'],
+        help='Type of data to return from dataset'
+    )
 
    args = parser.parse_args()
 
+    # Convert loss_type string to enum
+    args.loss_type = LossType(args.loss_type)
+
    # Create and validate TrainingArguments
    return TrainingArguments(**vars(args))
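Because argparse validates the raw string through choices before LossType(args.loss_type) re-keys it into the enum, downstream code can compare against LossType members instead of bare strings. A small self-contained sketch of that lookup-by-value pattern:

    from enum import Enum

    class LossType(Enum):
        MSE = "mse"
        MAE = "mae"
        HUBER = "huber"

    # Enum lookup by value, as done right after parser.parse_args()
    assert LossType("huber") is LossType.HUBER
    assert LossType("mse").value == "mse"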
src/main/trainer.py CHANGED

@@ -10,8 +10,10 @@ training loop management, evaluation, and result logging.
 import torch
 from torch import nn, optim
 from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
+from torch.utils.tensorboard.writer import SummaryWriter
 from typing import Dict, Tuple, Type, Union
+import logging
+from tqdm import tqdm
 
 from .parser import TrainingArguments
 from src.data.dataset import MatDataset, get_test_dataloaders
@@ -21,14 +23,11 @@ from src.utils import (
     get_ls_mse_per_folder,
     get_model_details,
     get_test_stats_plot,
-    get_error_images
-
-
-    get_all_test_stats,
-    train_epoch,
-    eval_model,
-    predict_channels
+    get_error_images,
+    concat_complex_channel,
+    to_db
 )
+from src.config.schemas import SystemConfig, ModelConfig
 
 # A union type representing supported model classes
 ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
@@ -48,16 +47,18 @@
     Attributes:
         MODEL_REGISTRY: Dictionary mapping model names to model classes
         system_config: OFDM system configuration
-
+        model_config: OFDM model configuration
+        args: Training arguments parsed from command line
         device: PyTorch device for computation
         writer: TensorBoard SummaryWriter for logging
-        model: Initialized model instance
+        model: Initialized Torch model instance
         optimizer: Torch optimizer for training
-        scheduler: Learning rate scheduler
+        scheduler: Learning rate scheduler for training
         early_stopper: Helper for early stopping
-        train_loader: DataLoader for training set
-        val_loader: DataLoader for validation set
-        test_loaders: Dictionary of test set DataLoaders
+        train_loader: DataLoader for training set (used for training)
+        val_loader: DataLoader for validation set (used for validation)
+        test_loaders: Dictionary of test set DataLoaders (used for testing)
+        logger: Logger instance for logging messages
     """
 
     MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
@@ -66,47 +67,33 @@
         "fortitran": FortiTranEstimator,
     }
 
-
+    EXP_LR_GAMMA = 0.995
+
+    def __init__(self, system_config: SystemConfig, model_config: ModelConfig | None, args: TrainingArguments):
         """
         Initialize the ModelTrainer.
 
         Args:
-
-
+            system_config: OFDM system configuration dictionary from YAML file
+            model_config: OFDM model configuration dictionary from YAML file
+            args: Validated training arguments parsed from command line
         """
         self.system_config = system_config
+        self.model_config = model_config
         self.args = args
         self.device = torch.device(f"cuda:{args.cuda}")
         self.writer = self._setup_tensorboard()
+        self.logger = logging.getLogger(__name__)
 
         self.model = self._initialize_model()
         self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
-        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=
+        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
         self.early_stopper = EarlyStopping(patience=args.patience)
 
-        self.training_loss =
-        self.comparison_loss = nn.MSELoss()  # used for test set evaluation
+        self.training_loss = nn.MSELoss()
 
         self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
 
-    def _get_loss_function(self) -> nn.Module:
-        """Get the appropriate loss function based on arguments.
-
-        Returns:
-            The selected PyTorch loss function based on args.loss_type
-
-        Raises:
-            ValueError: If an unsupported loss type is specified
-        """
-        if self.args.loss_type == LossType.MSE:
-            return nn.MSELoss()
-        elif self.args.loss_type == LossType.MAE:
-            return nn.L1Loss()
-        elif self.args.loss_type == LossType.HUBER:
-            return nn.HuberLoss()
-        else:
-            raise ValueError(f"Unsupported loss type: {self.args.loss_type}")
-
     def _setup_tensorboard(self) -> SummaryWriter:
         """Set up TensorBoard logging.
 
@@ -134,38 +121,30 @@
             Initialized model instance of the specified type
         """
         model_class = self.MODEL_REGISTRY[self.args.model_name]
-
-
+        if model_class is LinearEstimator:
+            model = model_class(self.system_config, device=str(self.device))
+        else:
+            if self.model_config is None:
+                raise ValueError("model_config must be provided for non-linear models.")
+            model = model_class(self.system_config, self.model_config)
         num_params, model_summary = get_model_details(model)
-
-
+        self.logger.info("\n" + model_summary)
+        self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
+        self.writer.add_text("Model Summary", model_summary)
         self.writer.add_text("Number of Parameters", str(num_params))
-
         return model
 
     def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
-        """
-
-        Creates DataLoader instances for:
-        - Training dataset
-        - Validation dataset
-        - Test datasets grouped by test condition (DS, MDS, SNR)
-
-        Returns:
-            Tuple containing (train_loader, val_loader, test_loaders_dict)
-        """
+        pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]
         # Training and validation dataloaders
         train_dataset = MatDataset(
             self.args.train_set,
-
-            return_type=self.config["return_type"]
+            pilot_dims
         )
         val_dataset = MatDataset(
             self.args.val_set,
-
-            return_type=self.config["return_type"]
+            pilot_dims
         )
-
         train_loader = DataLoader(
             train_dataset,
             batch_size=self.args.batch_size,
@@ -176,43 +155,34 @@
             batch_size=self.args.batch_size,
             shuffle=True
         )
-
-        # Test dataloaders
         test_loaders = {
             "DS": get_test_dataloaders(
                 self.args.test_set / "DS_test_set",
-
-                self.config["return_type"]
+                {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
             ),
             "MDS": get_test_dataloaders(
                 self.args.test_set / "MDS_test_set",
-
-                self.config["return_type"]
+                {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
             ),
             "SNR": get_test_dataloaders(
                 self.args.test_set / "SNR_test_set",
-
-                self.config["return_type"]
+                {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
             ),
         }
-
         return train_loader, val_loader, test_loaders
 
     def _log_test_results(
         self,
         epoch: int,
-        test_stats: Dict[str, Dict]
-        ls_stats: Dict[str, Dict]
+        test_stats: Dict[str, Dict]
     ) -> None:
         """Log test results to TensorBoard.
 
-        Creates and logs visualizations
-        baseline LS estimator across different test conditions.
+        Creates and logs visualizations for model performance across different test conditions.
 
         Args:
             epoch: Current training epoch
             test_stats: Dictionary of test statistics for the model
-            ls_stats: Dictionary of test statistics for the LS baseline
         """
         for key in ("DS", "MDS", "SNR"):
             # Plot test statistics
@@ -220,16 +190,13 @@
                 tag=f"MSE vs. {key} (Epoch:{epoch + 1})",
                 figure=get_test_stats_plot(
                     x_name=key,
-                    stats=[test_stats[key]
-                    methods=[self.
+                    stats=[test_stats[key]],
+                    methods=[self.args.model_name]
                 )
             )
 
             # Plot error images
-            predicted_channels =
-                self.model,
-                self.test_loaders[key]
-            )
+            predicted_channels = self._predict_channels(self.test_loaders[key])
             self.writer.add_figure(
                 tag=f"{key} Error Images (Epoch:{epoch + 1})",
                 figure=get_error_images(
@@ -242,23 +209,12 @@
     def _run_tests(self, epoch: int) -> None:
         """Run tests and log results.
 
-        Evaluates the model on all test datasets
-        and logs performance metrics and visualizations.
+        Evaluates the model on all test datasets and logs performance metrics and visualizations.
 
         Args:
             epoch: Current training epoch
         """
-        ds_stats, mds_stats, snr_stats =
-            self.model,
-            self.test_loaders,
-            self.comparison_loss
-        )
-
-        ls_stats = {
-            "DS": get_ls_mse_per_folder(self.args.test_set / "DS_test_set"),
-            "MDS": get_ls_mse_per_folder(self.args.test_set / "MDS_test_set"),
-            "SNR": get_ls_mse_per_folder(self.args.test_set / "SNR_test_set")
-        }
+        ds_stats, mds_stats, snr_stats = self._get_all_test_stats()
 
         test_stats = {
             "DS": ds_stats,
@@ -266,7 +222,7 @@
             "SNR": snr_stats
         }
 
-        self._log_test_results(epoch, test_stats
+        self._log_test_results(epoch, test_stats)
 
     def _log_final_metrics(self, final_epoch: int) -> None:
         """Log final training metrics and hyperparameters.
@@ -286,11 +242,7 @@
 
         try:
             for key in ("DS", "MDS", "SNR"):
-                ds_stats, mds_stats, snr_stats =
-                    self.model,
-                    self.test_loaders,
-                    self.comparison_loss
-                )
+                ds_stats, mds_stats, snr_stats = self._get_all_test_stats()
                 ls_stats = {
                     "DS": get_ls_mse_per_folder(self.args.test_set / "DS_test_set"),
                     "MDS": get_ls_mse_per_folder(self.args.test_set / "MDS_test_set"),
@@ -309,13 +261,95 @@
                     key,
                     {
                         "LS": ls_stats[key][val],
-                        self.
+                        self.args.model_name: stats[val]
                     },
                     val
                 )
         except Exception as e:
             self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
 
+    def _compute_loss(self, estimated_channel, ideal_channel, loss_fn):
+        return loss_fn(
+            concat_complex_channel(estimated_channel),
+            concat_complex_channel(ideal_channel)
+        )
+
+    def _forward_pass(self, batch, model):
+        estimated_channel, ideal_channel, meta_data = batch
+        if hasattr(model, 'name') and model.name in ["fortitran", "MMSE"]:
+            h_est_re = model(torch.real(estimated_channel))
+            h_est_im = model(torch.imag(estimated_channel))
+            estimated_channel = torch.complex(h_est_re, h_est_im)
+        elif hasattr(model, 'name') and model.name == "adafortitran":
+            h_est_re = model(torch.real(estimated_channel), meta_data)
+            h_est_im = model(torch.imag(estimated_channel), meta_data)
+            estimated_channel = torch.complex(h_est_re, h_est_im)
+        else:
+            raise ValueError(f"Unknown model type: {getattr(model, 'name', type(model))}")
+        return estimated_channel, ideal_channel.to(model.device)
+
+    def _train_epoch(self):
+        train_loss = 0.0
+        self.model.train()
+        for batch in self.train_loader:
+            self.optimizer.zero_grad()
+            estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
+            output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
+            output.backward()
+            self.optimizer.step()
+            train_loss += (2 * output.item() * batch[0].size(0))
+        self.scheduler.step()
+        train_loss /= len(self.train_loader.dataset)
+        return train_loss
+
+    def _eval_model(self, eval_dataloader):
+        val_loss = 0.0
+        self.model.eval()
+        with torch.no_grad():
+            for batch in eval_dataloader:
+                estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
+                output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
+                val_loss += (2 * output.item() * batch[0].size(0))
+        val_loss /= len(eval_dataloader.dataset)
+        return val_loss
+
+    def _predict_channels(self, test_dataloaders):
+        channels = {}
+        sorted_loaders = sorted(
+            test_dataloaders,
+            key=lambda x: int(x[0].split("_")[1])
+        )
+        for name, test_dataloader in sorted_loaders:
+            with torch.no_grad():
+                batch = next(iter(test_dataloader))
+                estimated_channels, ideal_channels = self._forward_pass(batch, self.model)
+                var, val = name.split("_")
+                channels[int(val)] = {
+                    "estimated_channel": estimated_channels[0],
+                    "ideal_channel": ideal_channels[0]
+                }
+        return channels
+
+    def _get_test_stats(self, test_dataloaders):
+        stats = {}
+        sorted_loaders = sorted(
+            test_dataloaders,
+            key=lambda x: int(x[0].split("_")[1])
+        )
+        for name, test_dataloader in sorted_loaders:
+            var, val = name.split("_")
+            test_loss = self._eval_model(test_dataloader)
+            db_error = to_db(test_loss)
+            self.logger.info(f"{var}:{val} Test MSE: {db_error:.4f} dB")
+            stats[int(val)] = db_error
+        return stats
+
+    def _get_all_test_stats(self):
+        ds_stats = self._get_test_stats(self.test_loaders["DS"])
+        mds_stats = self._get_test_stats(self.test_loaders["MDS"])
+        snr_stats = self._get_test_stats(self.test_loaders["SNR"])
+        return ds_stats, mds_stats, snr_stats
+
     def train(self) -> None:
         """Execute the training loop.
 
@@ -325,62 +359,35 @@
         - Early stopping when validation loss plateaus
         - Logging final metrics and results
         """
-        try:
-            from tqdm import tqdm
-            use_tqdm = True
-        except ImportError:
-            use_tqdm = False
-            print("tqdm not found, progress bar will not be displayed")
-
         epoch = None
-
-        # Create progress bar if tqdm is available
-        if use_tqdm:
-            pbar = tqdm(range(self.args.max_epoch), desc="Training")
-        else:
-            pbar = range(self.args.max_epoch)
-
+        pbar = tqdm(range(self.args.max_epoch), desc="Training")
         for epoch in pbar:
             # Training step
-            train_loss =
-                self.model,
-                self.optimizer,
-                self.training_loss,
-                self.scheduler,
-                self.train_loader
-            )
+            train_loss = self._train_epoch()
             self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
 
             # Validation step
-            val_loss =
+            val_loss = self._eval_model(self.val_loader)
             self.writer.add_scalar('Loss/Val', val_loss, epoch + 1)
 
             # Update progress bar with loss info
-
-
-                f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
+            pbar.set_description(
+                f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
 
             if self.early_stopper.early_stop(val_loss):
-
-                pbar.write(f"Early stopping triggered at epoch {epoch + 1}")
-                else:
-                    print(f"Early stopping triggered at epoch {epoch + 1}")
+                pbar.write(f"Early stopping triggered at epoch {epoch + 1}")
                 break
 
             # Periodic testing
             if (epoch + 1) % self.args.test_every_n == 0:
                 message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
-
-                pbar.write(message)
-                else:
-                    print(message)
+                pbar.write(message)
                 self._run_tests(epoch)
-
         self._log_final_metrics(epoch)
         self.writer.close()
 
 
-def train(
+def train(system_config: SystemConfig, model_config: ModelConfig | None, args: TrainingArguments) -> None:
     """
     Train an OFDM channel estimation model.
 
@@ -388,11 +395,11 @@ def train(config: Dict, args: TrainingArguments) -> None:
     with the specified configuration and runs the training process.
 
     Args:
-
-
+        system_config: OFDM system configuration dictionary from YAML file
+        model_config: OFDM model configuration dictionary from YAML file
        args: Validated training arguments containing all necessary parameters
            for model training, including dataset paths, hyperparameters,
            and logging configuration
    """
-    trainer = ModelTrainer(
+    trainer = ModelTrainer(system_config, model_config, args)
    trainer.train()
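End to end, the refactor routes everything through the module-level train() wrapper; a minimal sketch mirroring what src/main.py does:

    from src.config.config_loader import load_config
    from src.main.parser import parse_arguments
    from src.main.trainer import train

    args = parse_arguments()
    system_config, model_config = load_config(args.system_config_path, args.model_config_path)
    train(system_config, model_config, args)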
src/models/adafortitran.py CHANGED

@@ -1,5 +1,5 @@
 from .fortitran import BaseFortiTranEstimator
-from src.config import SystemConfig, ModelConfig
+from src.config.schemas import SystemConfig, ModelConfig
 
 
 class AdaFortiTranEstimator(BaseFortiTranEstimator):
src/models/fortitran.py CHANGED

@@ -3,7 +3,7 @@ from torch import nn
 import logging
 from typing import Tuple, List, Optional
 
-from src.config import SystemConfig, ModelConfig
+from src.config.schemas import SystemConfig, ModelConfig
 from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, \
     ChannelAdapter
 
src/models/linear.py CHANGED

@@ -27,16 +27,17 @@ class LinearEstimator(nn.Module):
         width (int): number of pilots across OFDM symbols
     """
 
-    def __init__(self, config: SystemConfig) -> None:
+    def __init__(self, config: SystemConfig, device: str = "cpu") -> None:
         """Initialize the MMSE estimator.
 
         Args:
             config: Validated SystemConfig object containing OFDM system parameters
+            device: Device to use for computation (cpu, cuda, etc.)
         """
         super().__init__()
 
         self.config = config
-        self.device = torch.device(
+        self.device = torch.device(device)
         self.logger = logging.getLogger(__name__)
 
         # Extract dimensions from validated config
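With the added device parameter, the baseline estimator is placed explicitly rather than on a hardcoded device; this matches how the trainer now constructs it via model_class(self.system_config, device=str(self.device)). A sketch assuming a validated SystemConfig instance named system_config is already in scope:

    from src.models.linear import LinearEstimator

    # system_config: a validated SystemConfig (e.g. from load_config)
    model = LinearEstimator(system_config, device="cpu")
    print(model.device)  # torch.device('cpu')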