Commit 4fa78a3 (parent: 71dbdc8): new README.md and Trainer design

Files changed:
- README.md                  +432 -2
- requirements.txt           +5   -2
- src/main/parser.py         +145 -36
- src/main/train_helpers.py  +0   -265
- src/main/trainer.py        +472 -110
- src/utils.py               +0   -17
README.md (CHANGED)

@@ -1,7 +1,437 @@
# AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation

[](LICENSE)
[](https://www.python.org/)
[](https://pytorch.org/)

Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076), accepted at ICC 2025, Montreal, Canada.

## 📖 Overview

AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.

### Key Features

- **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
- **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
- **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
- **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
- **🚀 Production Ready**: Comprehensive training pipeline with advanced features

## 🏗️ Architecture

The project implements three model variants:

1. **Linear Estimator**: Simple learned linear transformation baseline
2. **FortiTran**: Fixed transformer-based channel estimator
3. **AdaFortiTran**: Adaptive transformer with channel condition awareness

### Model Comparison

| Model | Channel Adaptation | Complexity | Performance |
|-------|--------------------|------------|-------------|
| Linear | ❌ | Low | Baseline |
| FortiTran | ❌ | Medium | Good |
| AdaFortiTran | ✅ | High | **Best** |

## 🚀 Quick Start

### Installation

1. **Clone the repository**:
```bash
git clone https://github.com/your-username/AdaFortiTran.git
cd AdaFortiTran
```

2. **Install dependencies**:
```bash
pip install -r requirements.txt
```

3. **Verify installation**:
```bash
python -c "import torch; print(f'PyTorch {torch.__version__}')"
```

### Basic Training

Train an AdaFortiTran model with default settings:

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id my_experiment
```

### Advanced Training

Use all available features for optimal performance:

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id advanced_experiment \
    --batch_size 128 \
    --lr 5e-4 \
    --max_epoch 100 \
    --patience 10 \
    --weight_decay 1e-4 \
    --gradient_clip_val 1.0 \
    --use_mixed_precision \
    --save_every_n_epochs 5 \
    --num_workers 8 \
    --test_every_n 5
```

## 📁 Project Structure

```
AdaFortiTran/
├── config/                    # Configuration files
│   ├── system_config.yaml     # OFDM system parameters
│   ├── adafortitran.yaml      # AdaFortiTran model config
│   ├── fortitran.yaml         # FortiTran model config
│   └── linear.yaml            # Linear model config
├── data/                      # Dataset directory
│   ├── train/                 # Training data
│   ├── val/                   # Validation data
│   └── test/                  # Test data (DS, MDS, SNR sets)
├── src/                       # Source code
│   ├── main/                  # Training pipeline
│   │   ├── trainer.py         # Enhanced ModelTrainer
│   │   └── parser.py          # Command-line argument parser
│   ├── models/                # Model implementations
│   │   ├── adafortitran.py    # AdaFortiTran model
│   │   ├── fortitran.py       # FortiTran model
│   │   ├── linear.py          # Linear model
│   │   └── blocks/            # Model building blocks
│   ├── data/                  # Data loading
│   │   └── dataset.py         # Dataset and DataLoader classes
│   ├── config/                # Configuration management
│   │   ├── config_loader.py   # YAML configuration loader
│   │   └── schemas.py         # Pydantic validation schemas
│   └── utils.py               # Utility functions
├── requirements.txt           # Python dependencies
└── README.md                  # This file
```

## ⚙️ Configuration

### System Configuration (`config/system_config.yaml`)

Defines OFDM system parameters:

```yaml
ofdm:
  num_scs: 120        # Number of subcarriers
  num_symbols: 14     # Number of OFDM symbols

pilot:
  num_scs: 12         # Number of pilot subcarriers
  num_symbols: 2      # Number of pilot symbols
```

### Model Configuration (`config/adafortitran.yaml`)

Defines model architecture parameters:

```yaml
model_type: 'adafortitran'
patch_size: [3, 2]              # Patch dimensions
num_layers: 6                   # Transformer layers
model_dim: 128                  # Model dimension
num_head: 4                     # Attention heads
activation: 'gelu'              # Activation function
dropout: 0.1                    # Dropout rate
max_seq_len: 512                # Maximum sequence length
pos_encoding_type: 'learnable'  # Positional encoding
channel_adaptivity_hidden_sizes: [7, 42, 560]  # Adaptation layers
adaptive_token_length: 6        # Adaptive token length
```
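Assuming the patch embedding tiles the full OFDM resource grid with non-overlapping patches, these defaults give (120 / 3) × (14 / 2) = 40 × 7 = 280 tokens per frame, which fits within `max_seq_len: 512` with room to spare for the adaptive tokens.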
## 🎯 Training Features

### Advanced Training Options

| Feature | Description | Default |
|---------|-------------|---------|
| `--use_mixed_precision` | Enable mixed precision training | False |
| `--gradient_clip_val` | Gradient clipping value | None |
| `--weight_decay` | Weight decay for optimizer | 0.0 |
| `--save_checkpoints` | Enable model checkpointing | True |
| `--save_best_only` | Save only best model | True |
| `--resume_from_checkpoint` | Resume from checkpoint | None |
| `--num_workers` | Data loading workers | 4 |
| `--pin_memory` | Pin memory for GPU | True |

### Callback System

The training pipeline includes an extensible callback system:

- **TensorBoard Logging**: Automatic metric tracking and visualization
- **Checkpoint Management**: Flexible checkpoint saving strategies
- **Custom Callbacks**: Easy to add new logging or monitoring systems

### Performance Optimizations

- **Mixed Precision Training**: Faster training on modern GPUs
- **Optimized Data Loading**: Configurable workers and memory pinning
- **Gradient Clipping**: Stable training with configurable clipping
- **Early Stopping**: Automatic training termination on plateau

## 📊 Dataset Format

### Expected File Structure

```
data/
├── train/
│   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   └── ...
├── val/
│   └── ...
└── test/
    ├── DS_test_set/      # Delay Spread tests
    │   ├── DS_50/
    │   ├── DS_100/
    │   └── ...
    ├── SNR_test_set/     # SNR tests
    │   ├── SNR_10/
    │   ├── SNR_20/
    │   └── ...
    └── MDS_test_set/     # Max Doppler Shift tests
        ├── DOP_200/
        ├── DOP_400/
        └── ...
```

### File Naming Convention

Files must follow the pattern:

```
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
```

Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`
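For quick dataset sanity checks, the condition metadata can be recovered from a filename with a regular expression. The sketch below is purely illustrative; the `parse_name` helper and its pattern are not part of the repository:

```python
import re
from pathlib import Path

# Illustrative pattern for the naming convention above; the repository's
# own dataset code may parse names differently.
PATTERN = re.compile(
    r"(?P<idx>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<ds>\d+)"
    r"_DOP-(?P<dop>\d+)_N-(?P<n>\d+)_(?P<channel>[\w-]+)\.mat"
)

def parse_name(path: Path) -> dict:
    """Return the metadata fields encoded in a dataset filename."""
    m = PATTERN.fullmatch(path.name)
    if m is None:
        raise ValueError(f"Unexpected file name: {path.name}")
    return m.groupdict()

print(parse_name(Path("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")))
# {'idx': '1', 'snr': '20', 'ds': '50', 'dop': '500', 'n': '3', 'channel': 'TDL-A'}
```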
### Data Format

Each `.mat` file must contain a variable `H` with shape `[subcarriers, symbols, 3]`:

- `H[:, :, 0]`: Ground truth channel (complex values)
- `H[:, :, 1]`: LS channel estimate, with zeros at non-pilot positions
- `H[:, :, 2]`: Reserved for future use
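A minimal sketch of loading one sample with SciPy (already in `requirements.txt`); the file path is illustrative:

```python
import numpy as np
from scipy.io import loadmat

mat = loadmat("data/train/1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
H = mat["H"]                   # shape: [subcarriers, symbols, 3]
ideal = H[:, :, 0]             # ground-truth channel (complex)
ls_estimate = H[:, :, 1]       # LS estimate, zeros off the pilot grid
pilot_mask = ls_estimate != 0  # recover pilot positions from the non-zeros
print(H.shape, np.count_nonzero(pilot_mask))
```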
## 🔧 Usage Examples

### Training Different Models

**Linear Estimator**:
```bash
python src/main.py \
    --model_name linear \
    --system_config_path config/system_config.yaml \
    --model_config_path config/linear.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id linear_baseline
```

**FortiTran**:
```bash
python src/main.py \
    --model_name fortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/fortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id fortitran_experiment
```

**AdaFortiTran**:
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id adafortitran_experiment
```

### Resume Training

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id resumed_experiment \
    --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
```

### Hyperparameter Tuning

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id hyperparameter_tuning \
    --batch_size 64 \
    --lr 1e-3 \
    --max_epoch 50 \
    --patience 5 \
    --weight_decay 1e-5 \
    --gradient_clip_val 0.5 \
    --use_mixed_precision \
    --test_every_n 5
```

## 📈 Monitoring and Logging

### TensorBoard Integration

Training automatically logs metrics to TensorBoard:

```bash
tensorboard --logdir runs/
```

Available metrics:
- Training/validation loss
- Learning rate
- Test performance across conditions
- Error visualizations
- Model hyperparameters

### Log Files

Training logs are saved to:
- `logs/training_{exp_id}.log`: Python logging output
- `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints

## 🧪 Testing and Evaluation

### Automatic Testing

The training pipeline automatically evaluates models on:
- **DS (Delay Spread)**: Varying delay spread conditions
- **SNR**: Different signal-to-noise ratios
- **MDS (Max Doppler Shift)**: Various Doppler shift scenarios

### Manual Evaluation

```python
import torch

from src.models import AdaFortiTranEstimator
from src.config import load_config

# Load configurations
system_config, model_config = load_config(
    'config/system_config.yaml',
    'config/adafortitran.yaml'
)

# Initialize model
model = AdaFortiTranEstimator(system_config, model_config)

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate
model.eval()
# ... evaluation code
```
## 🔬 Research and Development

### Adding Custom Callbacks

Because `Callback` is an abstract base class, a custom callback must override all four hooks (no-ops are fine for the ones you don't need):

```python
from src.main.trainer import Callback, TrainingMetrics

class CustomCallback(Callback):
    # Callback declares all four hooks abstract, so each needs an override.
    def on_epoch_begin(self, epoch: int) -> None: ...
    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
        # Custom logic here
        print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
    def on_training_begin(self) -> None: ...
    def on_training_end(self) -> None: ...
```
### Extending Models

The modular architecture makes it easy to add new model variants:

```python
from src.models.fortitran import BaseFortiTranEstimator

class CustomEstimator(BaseFortiTranEstimator):
    def __init__(self, system_config, model_config):
        super().__init__(system_config, model_config, use_channel_adaptation=True)
        # Add custom components
```

## 🐛 Troubleshooting

### Common Issues

**CUDA Out of Memory**:
- Reduce batch size: `--batch_size 32`
- Enable mixed precision: `--use_mixed_precision`
- Reduce number of workers: `--num_workers 2`

**Slow Training**:
- Increase number of workers: `--num_workers 8`
- Enable pin memory: `--pin_memory`
- Use mixed precision: `--use_mixed_precision`

**Poor Convergence**:
- Adjust learning rate: `--lr 1e-4`
- Add gradient clipping: `--gradient_clip_val 1.0`
- Increase patience: `--patience 10`

### Getting Help

1. Check the logs in `logs/training_{exp_id}.log`
2. Verify the dataset format matches the requirements
3. Ensure all dependencies are installed correctly
4. Check TensorBoard for training curves

## 📚 Citation

If you use this code in your research, please cite:

```bibtex
@misc{guler2025adafortitranadaptivetransformermodel,
  title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
  author={Berkay Guler and Hamid Jafarkhani},
  year={2025},
  eprint={2505.09076},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2505.09076},
}
```

## 📄 License

This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
requirements.txt (CHANGED)

@@ -1,5 +1,8 @@
 torch
 pydantic
+pyyaml
 scipy
 tqdm
+matplotlib
+prettytable
+tensorboard
src/main/parser.py (CHANGED)

@@ -10,7 +10,7 @@ of training runs.
 from pathlib import Path
 import argparse
 from pydantic import BaseModel, Field, model_validator
-from typing import Self
+from typing import Self, Optional


 class TrainingArguments(BaseModel):

@@ -41,9 +41,22 @@ class TrainingArguments(BaseModel):
         lr: Learning rate for optimizer
         max_epoch: Maximum number of training epochs
         patience: Early stopping patience in epochs
+        weight_decay: Weight decay for optimizer
+        gradient_clip_val: Gradient clipping value
+        use_mixed_precision: Whether to use mixed precision training

         # Evaluation
         test_every_n: Number of epochs between test evaluations
+
+        # Checkpointing
+        save_checkpoints: Whether to save model checkpoints
+        save_best_only: Whether to save only the best model
+        save_every_n_epochs: Save checkpoint every N epochs
+        resume_from_checkpoint: Path to checkpoint to resume from
+
+        # Data Loading
+        num_workers: Number of data loading workers
+        pin_memory: Whether to pin memory for faster GPU transfer
     """

     # Model Configuration

@@ -67,10 +80,23 @@ class TrainingArguments(BaseModel):
     lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
     max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
     patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
+    weight_decay: float = Field(default=0.0, ge=0.0, description="Weight decay for optimizer")
+    gradient_clip_val: Optional[float] = Field(default=None, gt=0, description="Gradient clipping value")
+    use_mixed_precision: bool = Field(default=False, description="Whether to use mixed precision training")

     # Evaluation
     test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")

+    # Checkpointing
+    save_checkpoints: bool = Field(default=True, description="Whether to save model checkpoints")
+    save_best_only: bool = Field(default=True, description="Whether to save only the best model")
+    save_every_n_epochs: Optional[int] = Field(default=None, gt=0, description="Save checkpoint every N epochs")
+    resume_from_checkpoint: Optional[Path] = Field(default=None, description="Path to checkpoint to resume from")
+
+    # Data Loading
+    num_workers: int = Field(default=4, ge=0, description="Number of data loading workers")
+    pin_memory: bool = Field(default=True, description="Whether to pin memory for faster GPU transfer")
+
     @model_validator(mode='after')
     def validate_paths(self) -> Self:
         """Validate path-related arguments.

@@ -92,6 +118,13 @@ class TrainingArguments(BaseModel):
         if not self.model_config_path.suffix == '.yaml':
             raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")

+        # Validate checkpoint path if provided
+        if self.resume_from_checkpoint is not None:
+            if not self.resume_from_checkpoint.exists():
+                raise ValueError(f"Checkpoint file not found: {self.resume_from_checkpoint}")
+            if not self.resume_from_checkpoint.suffix == '.pt':
+                raise ValueError(f"Checkpoint file must be a .pt file: {self.resume_from_checkpoint}")
+
         return self

@@ -161,58 +194,134 @@ def parse_arguments() -> TrainingArguments:
         help='Experiment identifier for log folder naming'
     )

-    optional.add_argument(
-        '--…',
-        type=str,
-        default="INFO",
-        help='Logger level for python logging module'
-    )
-    optional.add_argument(
-        '--tensorboard_log_dir',
-        type=Path,
-        default="runs",
-        help='Directory for tensorboard logs'
-    )
-    optional.add_argument(
-        '--python_log_dir',
-        type=Path,
-        default="logs",
-        help='Directory for python logging files'
-    )
-    optional.add_argument(
-        '--test_every_n',
-        type=int,
-        default=…,
-        help='…'
-    )
-    optional.add_argument(
-        '--max_epoch',
-        type=int,
-        default=10,
-        help='Maximum number of training epochs'
-    )
-    optional.add_argument(
-        '--patience',
-        type=int,
-        default=3,
-        help='Early stopping patience (epochs)'
-    )
-    optional.add_argument(
-        '--…',
-        type=int,
-        default=…,
-        help='…'
-    )
+    # Training hyperparameters
+    training = parser.add_argument_group('training hyperparameters')
+    training.add_argument(
+        '--batch_size',
+        type=int,
+        default=64,
+        help='Training batch size'
+    )
+    training.add_argument(
+        '--lr',
+        type=float,
+        default=1e-3,
+        help='Initial learning rate'
+    )
+    training.add_argument(
+        '--max_epoch',
+        type=int,
+        default=10,
+        help='Maximum number of training epochs'
+    )
+    training.add_argument(
+        '--patience',
+        type=int,
+        default=3,
+        help='Early stopping patience (epochs)'
+    )
+    training.add_argument(
+        '--weight_decay',
+        type=float,
+        default=0.0,
+        help='Weight decay for optimizer'
+    )
+    training.add_argument(
+        '--gradient_clip_val',
+        type=float,
+        default=None,
+        help='Gradient clipping value (disabled if not specified)'
+    )
+    training.add_argument(
+        '--use_mixed_precision',
+        action='store_true',
+        help='Use mixed precision training (requires PyTorch >= 1.6)'
+    )
+
+    # Evaluation settings
+    evaluation = parser.add_argument_group('evaluation settings')
+    evaluation.add_argument(
+        '--test_every_n',
+        type=int,
+        default=10,
+        help='Test model every N epochs'
+    )
+
+    # Checkpointing settings
+    checkpointing = parser.add_argument_group('checkpointing settings')
+    checkpointing.add_argument(
+        '--save_checkpoints',
+        action='store_true',
+        default=True,
+        help='Save model checkpoints'
+    )
+    checkpointing.add_argument(
+        '--no_save_checkpoints',
+        action='store_false',
+        dest='save_checkpoints',
+        help='Disable saving model checkpoints'
+    )
+    checkpointing.add_argument(
+        '--save_best_only',
+        action='store_true',
+        default=True,
+        help='Save only the best model based on validation loss'
+    )
+    checkpointing.add_argument(
+        '--save_every_n_epochs',
+        type=int,
+        default=None,
+        help='Save checkpoint every N epochs (in addition to best model)'
+    )
+    checkpointing.add_argument(
+        '--resume_from_checkpoint',
+        type=Path,
+        default=None,
+        help='Path to checkpoint file to resume training from'
+    )
+
+    # Data loading settings
+    data_loading = parser.add_argument_group('data loading settings')
+    data_loading.add_argument(
+        '--num_workers',
+        type=int,
+        default=4,
+        help='Number of data loading workers'
+    )
+    data_loading.add_argument(
+        '--pin_memory',
+        action='store_true',
+        default=True,
+        help='Pin memory for faster GPU transfer'
+    )
+    data_loading.add_argument(
+        '--no_pin_memory',
+        action='store_false',
+        dest='pin_memory',
+        help='Disable pin memory'
+    )
+
+    # Logging settings
+    logging_group = parser.add_argument_group('logging settings')
+    logging_group.add_argument(
+        '--python_log_level',
+        type=str,
+        default="INFO",
+        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+        help='Logger level for python logging module'
+    )
+    logging_group.add_argument(
+        '--tensorboard_log_dir',
+        type=Path,
+        default="runs",
+        help='Directory for tensorboard logs'
+    )
+    logging_group.add_argument(
+        '--python_log_dir',
+        type=Path,
+        default="logs",
+        help='Directory for python logging files'
+    )

     args = parser.parse_args()
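Not shown in this hunk: how the parsed namespace becomes the validated model. A minimal sketch, assuming the CLI flag names map one-to-one onto `TrainingArguments` fields:

```python
# Hypothetical tail of parse_arguments(); the actual implementation may
# post-process some fields before validation.
args = parser.parse_args()
return TrainingArguments(**vars(args))  # pydantic runs validate_paths() here
```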
src/main/train_helpers.py (DELETED)

@@ -1,265 +0,0 @@
"""
Training helper functions for OFDM channel estimation models.

This module provides utility functions for training, evaluating, and testing
deep learning models for OFDM channel estimation tasks. It includes functions
for performing training epochs, model evaluation, prediction generation,
and performance statistics calculation across different test conditions.
"""

from typing import Dict, List, Tuple, Union, Callable
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR
from src.utils import to_db, concat_complex_channel

# Type aliases
ComplexTensor = torch.Tensor  # Complex tensor
BatchType = Tuple[ComplexTensor, ComplexTensor, Union[Dict, None]]
TestDataLoadersType = List[Tuple[str, DataLoader]]
StatsType = Dict[int, float]


def get_all_test_stats(
    model: nn.Module,
    test_dataloaders: Dict[str, TestDataLoadersType],
    loss_fn: Callable
) -> Tuple[StatsType, StatsType, StatsType]:
    """
    Evaluate model on all test datasets.

    Calculates performance statistics (MSE in dB) for a model across different
    test conditions: Delay Spread (DS), Max Doppler Shift (MDS), and
    Signal-to-Noise Ratio (SNR).

    Args:
        model: Model to evaluate
        test_dataloaders: Dictionary containing DataLoader objects for test sets:
            - "DS": Delay Spread test set
            - "MDS": Max Doppler Shift test set
            - "SNR": Signal-to-Noise Ratio test set
        loss_fn: Loss function for evaluation

    Returns:
        Tuple containing statistics (MSE in dB) for DS, MDS, and SNR test sets,
        where each set of statistics is a dictionary mapping parameter values to MSE
    """
    ds_stats = get_test_stats(model, test_dataloaders["DS"], loss_fn)
    mds_stats = get_test_stats(model, test_dataloaders["MDS"], loss_fn)
    snr_stats = get_test_stats(model, test_dataloaders["SNR"], loss_fn)
    return ds_stats, mds_stats, snr_stats


def get_test_stats(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType,
    loss_fn: Callable
) -> StatsType:
    """
    Evaluate model on provided test dataloaders.

    Calculates performance statistics (MSE in dB) for a model on a
    specific set of test conditions.

    Args:
        model: Model to evaluate
        test_dataloaders: List of (name, DataLoader) tuples for test sets,
            where names are in format "parameter_value"
        loss_fn: Loss function for evaluation

    Returns:
        Dictionary mapping test parameter values (as integers) to MSE values in dB
    """
    stats: StatsType = {}
    sorted_loaders = sorted(
        test_dataloaders,
        key=lambda x: int(x[0].split("_")[1])
    )

    for name, test_dataloader in sorted_loaders:
        var, val = name.split("_")
        test_loss = eval_model(model, test_dataloader, loss_fn)
        db_error = to_db(test_loss)
        print(f"{var}:{val} Test MSE: {db_error:.4f} dB")
        stats[int(val)] = db_error

    return stats


def eval_model(
    model: nn.Module,
    eval_dataloader: DataLoader,
    loss_fn: Callable
) -> float:
    """
    Evaluate model on given dataloader.

    Calculates the average loss for a model on a dataset without
    performing parameter updates.

    Args:
        model: Model to evaluate
        eval_dataloader: DataLoader containing evaluation data
        loss_fn: Loss function for computing error

    Returns:
        Average validation loss (adjusted for complex values)

    Notes:
        Loss is multiplied by 2 to account for complex-valued matrices being
        represented as real-valued matrices of double size.
    """
    val_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in eval_dataloader:
            estimated_channel, ideal_channel = _forward_pass(batch, model)
            output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
            val_loss += (2 * output.item() * batch[0].size(0))

    val_loss /= sum(len(batch[0]) for batch in eval_dataloader)
    return val_loss


def predict_channels(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType
) -> Dict[int, Dict[str, ComplexTensor]]:
    """
    Generate channel predictions for test datasets.

    Creates predictions for a sample from each test dataset to enable
    visualization and error analysis.

    Args:
        model: Model to use for predictions
        test_dataloaders: List of (name, DataLoader) tuples for test sets,
            where names are in format "parameter_value"

    Returns:
        Dictionary mapping test parameter values (as integers) to dictionaries containing
        estimated and ideal channels for a single sample
    """
    channels: Dict[int, Dict[str, ComplexTensor]] = {}
    sorted_loaders = sorted(
        test_dataloaders,
        key=lambda x: int(x[0].split("_")[1])
    )

    for name, test_dataloader in sorted_loaders:
        with torch.no_grad():
            batch = next(iter(test_dataloader))
            estimated_channels, ideal_channels = _forward_pass(batch, model)

            var, val = name.split("_")
            channels[int(val)] = {
                "estimated_channel": estimated_channels[0],
                "ideal_channel": ideal_channels[0]
            }

    return channels


def train_epoch(
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: Callable,
    scheduler: ExponentialLR,
    train_dataloader: DataLoader
) -> float:
    """
    Train model for one epoch.

    Performs a complete training iteration over the dataset, including:
    - Forward pass through the model
    - Loss calculation
    - Backpropagation
    - Parameter updates
    - Learning rate scheduling

    Args:
        model: Model to train
        optimizer: Optimizer for updating model parameters
        loss_fn: Loss function for computing error
        scheduler: Learning rate scheduler
        train_dataloader: DataLoader containing training data

    Returns:
        Average training loss for the epoch (adjusted for complex values)

    Notes:
        Loss is multiplied by 2 to account for complex-valued matrices being
        represented as real-valued matrices of double size.
    """
    train_loss = 0.0
    model.train()

    for batch in train_dataloader:
        optimizer.zero_grad()
        estimated_channel, ideal_channel = _forward_pass(batch, model)
        output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
        output.backward()
        optimizer.step()
        train_loss += (2 * output.item() * batch[0].size(0))

    scheduler.step()
    train_loss /= sum(len(batch[0]) for batch in train_dataloader)
    return train_loss


def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, ComplexTensor]:
    """
    Perform forward pass through model.

    Processes input data through the appropriate model based on its type,
    handling different input requirements for different model architectures.

    Args:
        batch: Tuple containing (estimated_channel, ideal_channel, metadata)
        model: Model to use for processing

    Returns:
        Tuple of (processed_estimated_channel, ideal_channel)

    Raises:
        ValueError: If model type is not recognized
    """
    estimated_channel, ideal_channel, meta_data = batch

    # All models now handle complex input directly
    if hasattr(model, 'use_channel_adaptation') and model.use_channel_adaptation:
        # AdaFortiTran uses meta_data for channel adaptation
        estimated_channel = model(estimated_channel, meta_data)
    else:
        # Linear and FortiTran models don't use meta_data
        estimated_channel = model(estimated_channel)

    return estimated_channel, ideal_channel.to(model.device)


def _compute_loss(
    estimated_channel: ComplexTensor,
    ideal_channel: ComplexTensor,
    loss_fn: Callable
) -> torch.Tensor:
    """
    Calculate loss between estimated and ideal channels.

    Computes the loss between model output and ground truth using the specified
    loss function, with appropriate handling of complex values.

    Args:
        estimated_channel: Estimated channel from model
        ideal_channel: Ground truth ideal channel
        loss_fn: Loss function to compute error

    Returns:
        Computed loss value as a scalar tensor
    """
    return loss_fn(
        concat_complex_channel(estimated_channel),
        concat_complex_channel(ideal_channel)
    )
src/main/trainer.py (CHANGED)

@@ -11,9 +11,12 @@ import torch
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard.writer import SummaryWriter
-from typing import Dict, Tuple, Type, Union
 import logging
 from tqdm import tqdm

 from .parser import TrainingArguments
 from src.data.dataset import MatDataset, get_test_dataloaders

@@ -33,6 +36,291 @@ from src.config.schemas import SystemConfig, ModelConfig
 ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]

 class ModelTrainer:
     """Handles the training and evaluation of deep learning models.

@@ -59,6 +347,9 @@ class ModelTrainer:
         val_loader: DataLoader for validation set (used for validation)
         test_loaders: Dictionary of test set DataLoaders (used for testing)
         logger: Logger instance for logging messages
     """

     MODEL_REGISTRY: Dict[str, Type[ModelType]] = {

@@ -86,13 +377,59 @@ class ModelTrainer:
         self.logger = logging.getLogger(__name__)

         self.model = self._initialize_model()
-        …
         self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
         self.early_stopper = EarlyStopping(patience=args.patience)
-        …
         self.training_loss = nn.MSELoss()
         self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()

     def _setup_tensorboard(self) -> SummaryWriter:
         """Set up TensorBoard logging.

@@ -134,26 +471,30 @@ class ModelTrainer:
         return model

     def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
         pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]
         # Training and validation dataloaders
-        train_dataset = MatDataset(
-            …
-            …
-        )
-        val_dataset = MatDataset(
-            self.args.val_set,
-            pilot_dims
-        )
         train_loader = DataLoader(
             train_dataset,
             batch_size=self.args.batch_size,
-            shuffle=True
         )
         val_loader = DataLoader(
             val_dataset,
             batch_size=self.args.batch_size,
-            shuffle=…
         )
         test_loaders = {
             "DS": get_test_dataloaders(
                 self.args.test_set / "DS_test_set",

@@ -173,11 +514,7 @@ class ModelTrainer:
         }
         return train_loader, val_loader, test_loaders

-    def _log_test_results(
-        self,
-        epoch: int,
-        test_stats: Dict[str, Dict]
-    ) -> None:
         """Log test results to TensorBoard.

         Creates and logs visualizations for model performance across different test conditions.

@@ -198,7 +535,7 @@ class ModelTrainer:
         )

         # Plot error images
-        predicted_channels = self.…
         self.writer.add_figure(
             tag=f"{key} Error Images (Epoch:{epoch + 1})",
             figure=get_error_images(

@@ -208,15 +545,20 @@ class ModelTrainer:
         )
         )

-    def _run_tests(self, epoch: int) -> …
         """Run tests and log results.

         Evaluates the model on all test datasets and logs performance metrics and visualizations.

         Args:
             epoch: Current training epoch
         """
-        ds_stats…

         test_stats = {
             "DS": ds_stats,

@@ -225,6 +567,8 @@ class ModelTrainer:
         }

         self._log_test_results(epoch, test_stats)

     def _log_final_metrics(self, final_epoch: int) -> None:
         """Log final training metrics and hyperparameters.

@@ -270,92 +614,84 @@ class ModelTrainer:
         except Exception as e:
             self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")

-    def …
-        …
-    )

-    def …
-        …
-        self.…
-        self.…
-        self.…
-        num_samples = 0
-        with torch.no_grad():
-            for batch in eval_dataloader:
-                estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
-                output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
-                batch_size = batch[0].size(0)
-                val_loss += (2 * output.item() * batch_size)
-                num_samples += batch_size
-        val_loss /= num_samples
-        return val_loss
-
-    def _predict_channels(self, test_dataloaders):
-        channels = {}
-        sorted_loaders = sorted(
-            test_dataloaders,
-            key=lambda x: int(x[0].split("_")[1])
-        )
-        for name, test_dataloader in sorted_loaders:
-            with torch.no_grad():
-                batch = next(iter(test_dataloader))
-                estimated_channels, ideal_channels = self._forward_pass(batch, self.model)
-                var, val = name.split("_")
-                channels[int(val)] = {
-                    "estimated_channel": estimated_channels[0],
-                    "ideal_channel": ideal_channels[0]
-                }
-        return channels

-    def …
-        …

-    def …
-        …

@@ -366,21 +702,43 @@ class ModelTrainer:
         - Early stopping when validation loss plateaus
         - Logging final metrics and results
         """
         last_epoch = 0
         pbar = tqdm(range(self.args.max_epoch), desc="Training")
         for epoch in pbar:
             last_epoch = epoch
             # Training step
-            train_loss = self.…
-
             # Validation step
-            val_loss = self.…

             # Update progress bar with loss info
             pbar.set_description(
-                f"Epoch {epoch + 1}/{self.args.max_epoch} - …

             if self.early_stopper.early_stop(val_loss):
                 pbar.write(f"Early stopping triggered at epoch {epoch + 1}")

@@ -391,8 +749,12 @@ class ModelTrainer:
             message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
             pbar.write(message)
             self._run_tests(epoch)
         self._log_final_metrics(last_epoch)
-        …


 def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:

Added lines (new file contents):

@@ -11,9 +11,12 @@ import torch
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Dict, Tuple, Type, Union, Optional, List, Protocol
 import logging
 from tqdm import tqdm
+from dataclasses import dataclass
+from pathlib import Path
+from abc import ABC, abstractmethod

 from .parser import TrainingArguments
 from src.data.dataset import MatDataset, get_test_dataloaders

@@ -33,6 +36,291 @@ from src.config.schemas import SystemConfig, ModelConfig
 ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]


+@dataclass
+class TrainingMetrics:
+    """Container for training metrics."""
+    train_loss: float
+    val_loss: float
+    epoch: int
+    learning_rate: float
+
+
+@dataclass
+class TestResults:
+    """Container for test results."""
+    ds_stats: Dict[int, float]
+    mds_stats: Dict[int, float]
+    snr_stats: Dict[int, float]
+
+
+class Callback(ABC):
+    """Base class for training callbacks."""
+
+    @abstractmethod
+    def on_epoch_begin(self, epoch: int) -> None:
+        """Called at the beginning of each epoch."""
+        pass
+
+    @abstractmethod
+    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
+        """Called at the end of each epoch."""
+        pass
+
+    @abstractmethod
+    def on_training_begin(self) -> None:
+        """Called at the beginning of training."""
+        pass
+
+    @abstractmethod
+    def on_training_end(self) -> None:
+        """Called at the end of training."""
+        pass
+
+
+class CheckpointCallback(Callback):
+    """Callback for saving model checkpoints."""
+
+    def __init__(self, save_dir: Path, save_best_only: bool = True,
+                 save_every_n_epochs: Optional[int] = None):
+        self.save_dir = save_dir
+        self.save_best_only = save_best_only
+        self.save_every_n_epochs = save_every_n_epochs
+        self.best_val_loss = float('inf')
+        self.trainer = None
+
+    def set_trainer(self, trainer: 'ModelTrainer') -> None:
+        """Set the trainer reference."""
+        self.trainer = trainer
+
+    def on_epoch_begin(self, epoch: int) -> None:
+        pass
+
+    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
+        if self.trainer is None:
+            return
+
+        # Save best model
+        if self.save_best_only and metrics.val_loss < self.best_val_loss:
+            self.best_val_loss = metrics.val_loss
+            self.trainer.save_checkpoint(
+                epoch, metrics,
+                checkpoint_dir=self.save_dir / "best"
+            )
+
+        # Save every N epochs
+        if (self.save_every_n_epochs is not None and
+                (epoch + 1) % self.save_every_n_epochs == 0):
+            self.trainer.save_checkpoint(
+                epoch, metrics,
+                checkpoint_dir=self.save_dir / "periodic"
+            )
+
+    def on_training_begin(self) -> None:
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
+    def on_training_end(self) -> None:
+        pass
+
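`save_checkpoint` itself falls outside this hunk. A minimal sketch of what it plausibly persists, given that the README's loading example reads `checkpoint['model_state_dict']` and paths like `checkpoint_epoch_50.pt` (the other keys are assumptions):

```python
# Hypothetical ModelTrainer.save_checkpoint; only 'model_state_dict' and the
# file naming are confirmed elsewhere in this commit.
def save_checkpoint(self, epoch: int, metrics: TrainingMetrics,
                    checkpoint_dir: Path) -> None:
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": metrics.val_loss,
        },
        checkpoint_dir / f"checkpoint_epoch_{epoch + 1}.pt",
    )
```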
|
| 124 |
+
|
| 125 |
+
class TensorBoardCallback(Callback):
    """Callback for TensorBoard logging."""

    def __init__(self, writer: SummaryWriter):
        self.writer = writer

    def on_epoch_begin(self, epoch: int) -> None:
        pass

    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
        self.writer.add_scalar('Loss/Train', metrics.train_loss, metrics.epoch + 1)
        self.writer.add_scalar('Loss/Val', metrics.val_loss, metrics.epoch + 1)
        self.writer.add_scalar('Learning_Rate', metrics.learning_rate, metrics.epoch + 1)

    def on_training_begin(self) -> None:
        pass

    def on_training_end(self) -> None:
        self.writer.close()

class TrainingLoop:
    """Handles the core training loop logic."""

    def __init__(self, model: ModelType, optimizer: optim.Optimizer,
                 scheduler: optim.lr_scheduler.LRScheduler,
                 loss_fn: nn.Module, device: torch.device,
                 scaler: Optional[torch.cuda.amp.GradScaler] = None,
                 gradient_clip_val: Optional[float] = None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.device = device
        self.scaler = scaler
        self.gradient_clip_val = gradient_clip_val

    def _compute_loss(self, estimated_channel: torch.Tensor,
                      ideal_channel: torch.Tensor) -> torch.Tensor:
        """Compute loss between estimated and ideal channels."""
        return self.loss_fn(
            concat_complex_channel(estimated_channel),
            concat_complex_channel(ideal_channel)
        )

    def _forward_pass(self, batch: Tuple[torch.Tensor, torch.Tensor, Tuple],
                      model: ModelType) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform a forward pass through the model."""
        estimated_channel, ideal_channel, meta_data = batch

        # All models handle complex input directly
        if isinstance(model, AdaFortiTranEstimator):
            # AdaFortiTran uses meta_data for channel adaptation
            estimated_channel = model(estimated_channel, meta_data)
        else:
            # Linear and FortiTran models don't use meta_data
            estimated_channel = model(estimated_channel)

        return estimated_channel, ideal_channel.to(model.device)

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Train for one epoch and return the mean per-sample loss."""
        train_loss = 0.0
        self.model.train()
        num_samples = 0

        for batch in train_loader:
            self.optimizer.zero_grad()

            if self.scaler:
                # Run the forward pass and the loss under autocast so mixed
                # precision actually covers the model computation
                with torch.cuda.amp.autocast():
                    estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
                    loss = self._compute_loss(estimated_channel, ideal_channel)
                self.scaler.scale(loss).backward()

                # Gradient clipping (unscale first so the norm is measured
                # on the true gradients)
                if self.gradient_clip_val:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
                loss = self._compute_loss(estimated_channel, ideal_channel)
                loss.backward()

                # Gradient clipping
                if self.gradient_clip_val:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)

                self.optimizer.step()

            batch_size = batch[0].size(0)
            # The factor of 2 converts MSE over the stacked real/imag
            # representation into the equivalent complex-domain MSE
            train_loss += 2 * loss.item() * batch_size
            num_samples += batch_size

        self.scheduler.step()
        return train_loss / num_samples
    def evaluate(self, eval_loader: DataLoader) -> float:
        """Evaluate the model and return the mean per-sample loss."""
        val_loss = 0.0
        self.model.eval()
        num_samples = 0

        with torch.no_grad():
            for batch in eval_loader:
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
                        loss = self._compute_loss(estimated_channel, ideal_channel)
                else:
                    estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
                    loss = self._compute_loss(estimated_channel, ideal_channel)

                batch_size = batch[0].size(0)
                val_loss += 2 * loss.item() * batch_size
                num_samples += batch_size

        return val_loss / num_samples

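Why the `2 *` in the loss accumulation: the per-element MSE of the real/imaginary-stacked representation is exactly half of the complex-domain MSE, so doubling it recovers the complex error per sample. A quick, self-contained sanity check (the `stack_reim` helper is a stand-in for `concat_complex_channel`, whose exact layout is assumed here; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

h = torch.randn(4, 120, 14, dtype=torch.complex64)      # "ideal" channel
h_hat = torch.randn(4, 120, 14, dtype=torch.complex64)  # "estimated" channel

def stack_reim(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for concat_complex_channel: stack Re/Im along the last axis
    return torch.cat([x.real, x.imag], dim=-1)

real_mse = F.mse_loss(stack_reim(h_hat), stack_reim(h))
complex_mse = (h_hat - h).abs().pow(2).mean()

# Doubling the stacked-real MSE recovers the complex-domain MSE
assert torch.allclose(2 * real_mse, complex_mse, atol=1e-4)
```
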
class ModelEvaluator:
    """Handles model evaluation and testing."""

    def __init__(self, model: ModelType, device: torch.device, logger: logging.Logger):
        self.model = model
        self.device = device
        self.logger = logger

    def _forward_pass(self, batch: Tuple[torch.Tensor, torch.Tensor, Tuple],
                      model: ModelType) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform a forward pass through the model."""
        estimated_channel, ideal_channel, meta_data = batch

        if isinstance(model, AdaFortiTranEstimator):
            estimated_channel = model(estimated_channel, meta_data)
        else:
            estimated_channel = model(estimated_channel)

        return estimated_channel, ideal_channel.to(model.device)

    def predict_channels(self, test_dataloaders: List[Tuple[str, DataLoader]]) -> Dict[int, Dict]:
        """Predict one channel per test condition for visualization."""
        channels = {}
        # Loader names follow the "<condition>_<value>" convention,
        # so sort by the numeric value
        sorted_loaders = sorted(
            test_dataloaders,
            key=lambda x: int(x[0].split("_")[1])
        )

        self.model.eval()
        for name, test_dataloader in sorted_loaders:
            with torch.no_grad():
                batch = next(iter(test_dataloader))
                estimated_channels, ideal_channels = self._forward_pass(batch, self.model)

            var, val = name.split("_")
            channels[int(val)] = {
                "estimated_channel": estimated_channels[0],
                "ideal_channel": ideal_channels[0]
            }
        return channels

    def get_test_stats(self, test_dataloaders: List[Tuple[str, DataLoader]],
                       loss_fn: nn.Module) -> Dict[int, float]:
        """Get test statistics for a set of dataloaders."""
        stats = {}
        sorted_loaders = sorted(
            test_dataloaders,
            key=lambda x: int(x[0].split("_")[1])
        )

        for name, test_dataloader in sorted_loaders:
            var, val = name.split("_")
            test_loss = self._evaluate_dataloader(test_dataloader, loss_fn)
            db_error = to_db(test_loss)
            self.logger.info(f"{var}:{val} Test MSE: {db_error:.4f} dB")
            stats[int(val)] = db_error
        return stats

    def _evaluate_dataloader(self, dataloader: DataLoader, loss_fn: nn.Module) -> float:
        """Evaluate a single dataloader and return the mean per-sample loss."""
        total_loss = 0.0
        num_samples = 0
        self.model.eval()

        with torch.no_grad():
            for batch in dataloader:
                estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
                loss = loss_fn(
                    concat_complex_channel(estimated_channel),
                    concat_complex_channel(ideal_channel)
                )

                batch_size = batch[0].size(0)
                # Same factor-of-2 convention as in TrainingLoop
                total_loss += 2 * loss.item() * batch_size
                num_samples += batch_size

        return total_loss / num_samples

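For orientation, this is roughly how the evaluator is driven; the loaders below are hypothetical but their names follow the `"<condition>_<value>"` convention that the `split("_")` parsing assumes:

```python
evaluator = ModelEvaluator(model, device, logging.getLogger(__name__))
snr_loaders = [("SNR_5", snr5_loader), ("SNR_10", snr10_loader)]  # hypothetical loaders
stats = evaluator.get_test_stats(snr_loaders, nn.MSELoss())
# stats maps each SNR value to its test MSE in dB, e.g. {5: -12.3, 10: -18.7}
```
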
class ModelTrainer:
    """Handles the training and evaluation of deep learning models.

...

        val_loader: DataLoader for the validation set
        test_loaders: Dictionary of test-set DataLoaders
        logger: Logger instance for logging messages
        training_loop: TrainingLoop instance for the core training logic
        evaluator: ModelEvaluator instance for the evaluation logic
        callbacks: List of training callbacks
    """

    MODEL_REGISTRY: Dict[str, Type[ModelType]] = {

...
        self.logger = logging.getLogger(__name__)

        self.model = self._initialize_model()

        # Initialize optimizer with weight decay
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )

        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
        self.early_stopper = EarlyStopping(patience=args.patience)
        self.training_loss = nn.MSELoss()

        # Initialize mixed precision training if requested
        self.scaler = None
        if args.use_mixed_precision and self.device.type == 'cuda':
            self.scaler = torch.cuda.amp.GradScaler()
            self.logger.info("Mixed precision training enabled")

        self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()

        # Initialize components
        self.training_loop = TrainingLoop(
            self.model, self.optimizer, self.scheduler, self.training_loss,
            self.device, self.scaler, self.args.gradient_clip_val
        )
        self.evaluator = ModelEvaluator(self.model, self.device, self.logger)

        # Initialize callbacks
        self.callbacks = self._setup_callbacks()

        # Resume from checkpoint if specified
        if args.resume_from_checkpoint is not None:
            self._resume_from_checkpoint(args.resume_from_checkpoint)
    def _setup_callbacks(self) -> List[Callback]:
        """Set up training callbacks."""
        callbacks: List[Callback] = []

        # TensorBoard callback
        callbacks.append(TensorBoardCallback(self.writer))

        # Checkpoint callback (only if checkpointing is enabled)
        if self.args.save_checkpoints:
            checkpoint_dir = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
            checkpoint_callback = CheckpointCallback(
                save_dir=checkpoint_dir,
                save_best_only=self.args.save_best_only,
                save_every_n_epochs=self.args.save_every_n_epochs
            )
            checkpoint_callback.set_trainer(self)
            callbacks.append(checkpoint_callback)

        return callbacks
    def _setup_tensorboard(self) -> SummaryWriter:
        """Set up TensorBoard logging.

...

        return model
    def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, Dict[str, List[Tuple[str, DataLoader]]]]:
        """Get training, validation, and test dataloaders."""
        pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]

        # Training and validation dataloaders
        train_dataset = MatDataset(self.args.train_set, pilot_dims)
        val_dataset = MatDataset(self.args.val_set, pilot_dims)

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory and self.device.type == 'cuda'
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,  # No need to shuffle validation data
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory and self.device.type == 'cuda'
        )

        # Test dataloaders
        test_loaders = {
            "DS": get_test_dataloaders(
                self.args.test_set / "DS_test_set",

...

        }
        return train_loader, val_loader, test_loaders
    def _log_test_results(self, epoch: int, test_stats: Dict[str, Dict]) -> None:
        """Log test results to TensorBoard.

        Creates and logs visualizations of model performance across the different test conditions.

...

            )

            # Plot error images
            predicted_channels = self.evaluator.predict_channels(self.test_loaders[key])
            self.writer.add_figure(
                tag=f"{key} Error Images (Epoch:{epoch + 1})",
                figure=get_error_images(

...

                )
            )
    def _run_tests(self, epoch: int) -> TestResults:
        """Run tests and log results.

        Evaluates the model on all test datasets and logs performance metrics and visualizations.

        Args:
            epoch: Current training epoch

        Returns:
            TestResults containing all test statistics
        """
        ds_stats, mds_stats, snr_stats = self._get_all_test_stats()

        test_stats = {
            "DS": ds_stats,
            "MDS": mds_stats,
            "SNR": snr_stats,
        }

        self._log_test_results(epoch, test_stats)

        return TestResults(ds_stats, mds_stats, snr_stats)
    def _log_final_metrics(self, final_epoch: int) -> None:
        """Log final training metrics and hyperparameters.

...

        except Exception as e:
            self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
    def _get_all_test_stats(self) -> Tuple[Dict[int, float], Dict[int, float], Dict[int, float]]:
        """Get test statistics for the DS, MDS, and SNR test sets."""
        ds_stats = self.evaluator.get_test_stats(self.test_loaders["DS"], self.training_loss)
        mds_stats = self.evaluator.get_test_stats(self.test_loaders["MDS"], self.training_loss)
        snr_stats = self.evaluator.get_test_stats(self.test_loaders["SNR"], self.training_loss)
        return ds_stats, mds_stats, snr_stats
    def save_checkpoint(self, epoch: int, metrics: TrainingMetrics,
                        checkpoint_dir: Optional[Path] = None) -> None:
        """Save model checkpoint.

        Args:
            epoch: Current epoch number
            metrics: Current training metrics
            checkpoint_dir: Directory to save checkpoint (defaults to the TensorBoard log dir)
        """
        if checkpoint_dir is None:
            checkpoint_dir = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"

        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_loss': metrics.train_loss,
            'val_loss': metrics.val_loss,
            'learning_rate': metrics.learning_rate,
            'system_config': self.system_config,
            'model_config': self.model_config,
            'args': self.args
        }

        # Save scaler state if using mixed precision
        if self.scaler:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_path)
        self.logger.info(f"Checkpoint saved to {checkpoint_path}")
    def load_checkpoint(self, checkpoint_path: Path) -> int:
        """Load model checkpoint.

        Args:
            checkpoint_path: Path to checkpoint file

        Returns:
            Epoch number of the loaded checkpoint
        """
        # Note: the checkpoint stores config/args objects, so on PyTorch >= 2.6
        # (where torch.load defaults to weights_only=True) this may require
        # weights_only=False
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Load scaler state if it exists
        if self.scaler and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        return checkpoint['epoch']
    def _resume_from_checkpoint(self, checkpoint_path: Path) -> None:
        """Resume training state from a checkpoint.

        Note: `train()` always restarts its epoch counter at 0; resuming
        restores model, optimizer, scheduler, and early-stopper state.

        Args:
            checkpoint_path: Path to checkpoint file
        """
        start_epoch = self.load_checkpoint(checkpoint_path)
        self.logger.info(f"Resuming training from epoch {start_epoch}")

        # Seed the early stopper with the best validation loss from the
        # checkpoint so resumed training does not reset its patience baseline
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'val_loss' in checkpoint:
            self.early_stopper.min_loss = checkpoint['val_loss']
            self.logger.info(f"Early stopper initialized with validation loss: {checkpoint['val_loss']:.4f}")
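Taken together, resuming looks roughly like the sketch below; the checkpoint path is hypothetical, and the constructor arguments mirror the module-level `train` entry point:

```python
args.resume_from_checkpoint = Path("runs/adafortitran_exp1/best/checkpoint_epoch_42.pt")  # hypothetical path
trainer = ModelTrainer(system_config, model_config, args)  # __init__ calls _resume_from_checkpoint
trainer.train()  # continues with the restored model/optimizer/scheduler state
```
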
    def train(self) -> None:
        """Execute the training loop.

...

        - Early stopping when validation loss plateaus
        - Logging final metrics and results
        """
        # Notify callbacks that training is beginning
        for callback in self.callbacks:
            callback.on_training_begin()

        last_epoch = 0
        pbar = tqdm(range(self.args.max_epoch), desc="Training")

        for epoch in pbar:
            last_epoch = epoch

            # Notify callbacks that the epoch is beginning
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # Training step
            train_loss = self.training_loop.train_epoch(self.train_loader)

            # Validation step
            val_loss = self.training_loop.evaluate(self.val_loader)

            # Create metrics object
            metrics = TrainingMetrics(
                train_loss=train_loss,
                val_loss=val_loss,
                epoch=epoch,
                learning_rate=self.optimizer.param_groups[0]['lr']
            )

            # Notify callbacks that the epoch has ended
            for callback in self.callbacks:
                callback.on_epoch_end(epoch, metrics)

            # Update progress bar with loss info
            pbar.set_description(
                f"Epoch {epoch + 1}/{self.args.max_epoch} - "
                f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
            )

            if self.early_stopper.early_stop(val_loss):
                pbar.write(f"Early stopping triggered at epoch {epoch + 1}")

...

                message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
                pbar.write(message)
                self._run_tests(epoch)

        self._log_final_metrics(last_epoch)

        # Notify callbacks that training has ended
        for callback in self.callbacks:
            callback.on_training_end()


def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:
src/utils.py
CHANGED

@@ -180,24 +180,7 @@ def concat_complex_channel(channel_matrix):
     return cat_channel_m
 
 
-def inverse_concat_complex_channel(channel_matrix: torch.Tensor) -> torch.Tensor:
-    """
-    Reconstruct complex channel matrix from concatenated real matrix.
-
-    Reverses the operation performed by concat_complex_channel by
-    splitting the tensor and combining the parts into a complex tensor.
-
-    Args:
-        channel_matrix: Real-valued matrix of shape (B, F, 2*T)
-
-    Returns:
-        Complex matrix of shape (B, F, T)
-    """
-    split_idx = channel_matrix.shape[-1] // 2
-    return torch.complex(
-        channel_matrix[:, :split_idx],
-        channel_matrix[:, split_idx:]
-    )
 
 
 def get_test_stats_plot(x_name, stats, methods, show=False):
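
Worth noting about the removed helper: assuming the batched `(B, F, 2*T)` input its docstring describes, the slices `[:, :split_idx]` cut along the subcarrier axis rather than the last axis where real and imaginary parts are stacked. If the inverse is ever reintroduced, a corrected sketch (not part of this commit) would slice the last dimension:

```python
import torch

def inverse_concat_complex_channel(channel_matrix: torch.Tensor) -> torch.Tensor:
    """Rebuild a (B, F, T) complex matrix from its (B, F, 2*T) real form."""
    split_idx = channel_matrix.shape[-1] // 2
    # Slice along the last axis, where Re and Im were concatenated
    return torch.complex(
        channel_matrix[..., :split_idx],
        channel_matrix[..., split_idx:]
    )
```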