BerkIGuler committed on
Commit 301f5ca · 1 Parent(s): c871fdf

Remove .gitattributes, .huggingfaceignore, and README_original.md files to clean up the repository and streamline documentation.

Files changed (4)
  1. .gitattributes +0 -5
  2. .huggingfaceignore +0 -34
  3. README.md +394 -28
  4. README_original.md +0 -438
.gitattributes DELETED
@@ -1,5 +0,0 @@
- *.mat filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
.huggingfaceignore DELETED
@@ -1,34 +0,0 @@
- # Ignore large data files during upload
- data/train/
- data/val/
- data/test/
-
- # Ignore model checkpoints and logs
- *.ckpt
- *.pth
- *.pt
- logs/
- runs/
- checkpoints/
-
- # Ignore temporary files
- __pycache__/
- *.pyc
- *.pyo
- *.pyd
- .Python
- *.so
- .DS_Store
- Thumbs.db
-
- # Ignore IDE files
- .vscode/
- .idea/
- *.swp
- *.swo
-
- # Ignore environment files
- .env
- .venv/
- venv/
- env/
README.md CHANGED
@@ -1,35 +1,23 @@
- ---
- language:
- - en
- tags:
- - pytorch
- - transformer
- - channel-estimation
- - ofdm
- - wireless
- - adaptive
- license: mit
- datasets:
- - custom
- metrics:
- - mse
- ---
-
  # AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
  
- ## Model Description
-
- AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.
-
- ## Key Features
-
  - **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
  - **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
  - **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
  - **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
  - **🚀 Production Ready**: Comprehensive training pipeline with advanced features
  
- ## Architecture
  
  The project implements three model variants:
  
@@ -37,23 +25,399 @@ The project implements three model variants:
  2. **FortiTran**: Fixed transformer-based channel estimator
  3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
  
- ## Usage
  
  ### Installation
  
  ```bash
- pip install -r requirements.txt
  ```
  
- ### Training
  
  ```bash
- python src/main.py --model_name adafortitran --system_config_path config/system_config.yaml --model_config_path config/adafortitran.yaml --train_set data/train --val_set data/val --test_set data/test --exp_id my_experiment
  ```
  
- ## Citation
  
- If you use this model in your research, please cite:
  
  ```bibtex
  @misc{guler2025adafortitranadaptivetransformermodel,
@@ -67,6 +431,8 @@ If you use this model in your research, please cite:
  }
  ```
  
- ## License
  
  This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
  # AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
  
+ [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
+ [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)
  
+ Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076), accepted at ICC 2025, Montreal, Canada.
  
+ ## 📖 Overview
+ 
+ AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.
  
+ ### Key Features
  - **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
  - **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
  - **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
  - **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
  - **🚀 Production Ready**: Comprehensive training pipeline with advanced features
  
+ ## 🏗️ Architecture
  
  The project implements three model variants:
  
  1. **Linear Estimator**: Simple learned linear transformation baseline
  2. **FortiTran**: Fixed transformer-based channel estimator
  3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
  
+ ### Model Comparison
+ 
+ | Model        | Channel Adaptation | Complexity | Performance |
+ |--------------|--------------------|------------|-------------|
+ | Linear       | ❌                 | Low        | Baseline    |
+ | FortiTran    | ❌                 | Medium     | Good        |
+ | AdaFortiTran | ✅                 | High       | **Best**    |
+ 
+ ## 🚀 Quick Start
  
  ### Installation
  
+ 1. **Clone the repository**:
+    ```bash
+    git clone https://github.com/your-username/AdaFortiTran.git
+    cd AdaFortiTran
+    ```
+ 
+ 2. **Install dependencies**:
+    ```bash
+    pip install -r requirements.txt
+    ```
+ 
+ 3. **Verify installation**:
+    ```bash
+    python -c "import torch; print(f'PyTorch {torch.__version__}')"
+    ```
+ 
+ ### Basic Training
+ 
+ Train an AdaFortiTran model with default settings:
+ 
+ ```bash
+ python src/main.py \
+     --model_name adafortitran \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/adafortitran.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id my_experiment
+ ```
+ 
+ ### Advanced Training
+ 
+ Use all available features for optimal performance:
+ 
+ ```bash
+ python src/main.py \
+     --model_name adafortitran \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/adafortitran.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id advanced_experiment \
+     --batch_size 128 \
+     --lr 5e-4 \
+     --max_epoch 100 \
+     --patience 10 \
+     --weight_decay 1e-4 \
+     --gradient_clip_val 1.0 \
+     --use_mixed_precision \
+     --save_every_n_epochs 5 \
+     --num_workers 8 \
+     --test_every_n 5
+ ```
+ 
+ ## 📁 Project Structure
+ 
+ ```
+ AdaFortiTran/
+ ├── config/                  # Configuration files
+ │   ├── system_config.yaml   # OFDM system parameters
+ │   ├── adafortitran.yaml    # AdaFortiTran model config
+ │   ├── fortitran.yaml       # FortiTran model config
+ │   └── linear.yaml          # Linear model config
+ ├── data/                    # Dataset directory
+ │   ├── train/               # Training data
+ │   ├── val/                 # Validation data
+ │   └── test/                # Test data (DS, MDS, SNR sets)
+ ├── src/                     # Source code
+ │   ├── main/                # Training pipeline
+ │   │   ├── trainer.py       # Enhanced ModelTrainer
+ │   │   └── parser.py        # Command-line argument parser
+ │   ├── models/              # Model implementations
+ │   │   ├── adafortitran.py  # AdaFortiTran model
+ │   │   ├── fortitran.py     # FortiTran model
+ │   │   ├── linear.py        # Linear model
+ │   │   └── blocks/          # Model building blocks
+ │   ├── data/                # Data loading
+ │   │   └── dataset.py       # Dataset and DataLoader classes
+ │   ├── config/              # Configuration management
+ │   │   ├── config_loader.py # YAML configuration loader
+ │   │   └── schemas.py       # Pydantic validation schemas
+ │   └── utils.py             # Utility functions
+ ├── requirements.txt         # Python dependencies
+ └── README.md                # This file
+ ```
+ 
+ ## ⚙️ Configuration
+ 
+ ### System Configuration (`config/system_config.yaml`)
+ 
+ Defines OFDM system parameters:
+ 
+ ```yaml
+ ofdm:
+   num_scs: 120      # Number of subcarriers
+   num_symbols: 14   # Number of OFDM symbols
+ 
+ pilot:
+   num_scs: 12       # Number of pilot subcarriers
+   num_symbols: 2    # Number of pilot symbols
+ ```
+ 
+ ### Model Configuration (`config/adafortitran.yaml`)
+ 
+ Defines model architecture parameters:
+ 
+ ```yaml
+ model_type: 'adafortitran'
+ patch_size: [3, 2]                             # Patch dimensions
+ num_layers: 6                                  # Transformer layers
+ model_dim: 128                                 # Model dimension
+ num_head: 4                                    # Attention heads
+ activation: 'gelu'                             # Activation function
+ dropout: 0.1                                   # Dropout rate
+ max_seq_len: 512                               # Maximum sequence length
+ pos_encoding_type: 'learnable'                 # Positional encoding
+ channel_adaptivity_hidden_sizes: [7, 42, 560]  # Adaptation layers
+ adaptive_token_length: 6                       # Adaptive token length
+ ```
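Taken together, the two configs above imply the transformer's sequence length: a 120×14 channel matrix split into 3×2 patches gives 40 × 7 = 280 tokens, comfortably below `max_seq_len: 512`. A plain-Python sanity check (illustrative only; the actual patching logic lives in the repo's model code):

```python
# Token count implied by system_config.yaml and adafortitran.yaml above.
num_scs, num_symbols = 120, 14   # OFDM grid: subcarriers x symbols
patch_f, patch_t = 3, 2          # patch_size: [3, 2]

# Patches must tile the grid exactly for this simple calculation.
assert num_scs % patch_f == 0 and num_symbols % patch_t == 0

num_tokens = (num_scs // patch_f) * (num_symbols // patch_t)
print(num_tokens)  # 280, well under max_seq_len = 512
```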
+ 
+ ## 🎯 Training Features
+ 
+ ### Advanced Training Options
+ 
+ | Feature | Description | Default |
+ |---------|-------------|---------|
+ | `--use_mixed_precision` | Enable mixed precision training | False |
+ | `--gradient_clip_val` | Gradient clipping value | None |
+ | `--weight_decay` | Weight decay for optimizer | 0.0 |
+ | `--save_checkpoints` | Enable model checkpointing | True |
+ | `--save_best_only` | Save only the best model | True |
+ | `--resume_from_checkpoint` | Resume from a checkpoint | None |
+ | `--num_workers` | Data-loading workers | 4 |
+ | `--pin_memory` | Pin memory for GPU transfers | True |
+ 
+ ### Callback System
+ 
+ The training pipeline includes an extensible callback system:
+ 
+ - **TensorBoard Logging**: Automatic metric tracking and visualization
+ - **Checkpoint Management**: Flexible checkpoint-saving strategies
+ - **Custom Callbacks**: Easy to add new logging or monitoring systems
+ 
+ ### Performance Optimizations
+ 
+ - **Mixed Precision Training**: Faster training on modern GPUs
+ - **Optimized Data Loading**: Configurable workers and memory pinning
+ - **Gradient Clipping**: Stable training with configurable clipping
+ - **Early Stopping**: Automatic training termination on plateau
+ 
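The `--patience` option drives the early stopping listed above. As an illustration of the idea (a minimal sketch, not the trainer's actual implementation), it amounts to a counter over validation losses:

```python
class EarlyStopping:
    """Stop after `patience` epochs with no validation-loss improvement."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.82]
print([stopper.step(l) for l in losses])
# [False, False, False, False, False, True] -- stops after 3 flat epochs
```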
+ ## 📊 Dataset Format
+ 
+ ### Expected File Structure
+ 
+ ```
+ data/
+ ├── train/
+ │   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
+ │   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
+ │   └── ...
+ ├── val/
+ │   └── ...
+ └── test/
+     ├── DS_test_set/       # Delay Spread tests
+     │   ├── DS_50/
+     │   ├── DS_100/
+     │   └── ...
+     ├── SNR_test_set/      # SNR tests
+     │   ├── SNR_10/
+     │   ├── SNR_20/
+     │   └── ...
+     └── MDS_test_set/      # Multi-Doppler tests
+         ├── DOP_200/
+         ├── DOP_400/
+         └── ...
+ ```
+ 
+ ### File Naming Convention
+ 
+ Files must follow the pattern:
+ ```
+ {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
+ ```
+ 
+ Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`
+ 
+ ### Data Format
+ 
+ Each `.mat` file must contain a variable `H` with shape `[subcarriers, symbols, 3]`:
+ - `H[:, :, 0]`: Ground truth channel (complex values)
+ - `H[:, :, 1]`: LS channel estimate with zeros at non-pilot positions
+ - `H[:, :, 2]`: Reserved for future use
+ 
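The metadata embedded in this naming scheme can be recovered with a small regex parser. The following is an illustrative sketch; the repository's own dataset loader may parse filenames differently:

```python
import re

# Pattern for names like 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
FILENAME_RE = re.compile(
    r"(?P<idx>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<ds>\d+)"
    r"_DOP-(?P<dop>\d+)_N-(?P<n>\d+)_(?P<chan>[\w-]+)\.mat$"
)

def parse_filename(name: str) -> dict:
    """Extract SNR, delay spread, Doppler, pilot frequency, and channel type."""
    m = FILENAME_RE.match(name)
    if m is None:
        raise ValueError(f"unexpected filename: {name}")
    d = m.groupdict()
    return {k: (v if k == "chan" else int(v)) for k, v in d.items()}

meta = parse_filename("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
print(meta)  # {'idx': 1, 'snr': 20, 'ds': 50, 'dop': 500, 'n': 3, 'chan': 'TDL-A'}
```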
+ ## 🔧 Usage Examples
+ 
+ ### Training Different Models
+ 
+ **Linear Estimator**:
+ ```bash
+ python src/main.py \
+     --model_name linear \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/linear.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id linear_baseline
+ ```
+ 
+ **FortiTran**:
+ ```bash
+ python src/main.py \
+     --model_name fortitran \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/fortitran.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id fortitran_experiment
+ ```
+ 
+ **AdaFortiTran**:
+ ```bash
+ python src/main.py \
+     --model_name adafortitran \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/adafortitran.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id adafortitran_experiment
+ ```
+ 
+ ### Resume Training
+ 
  ```bash
+ python src/main.py \
+     --model_name adafortitran \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/adafortitran.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id resumed_experiment \
+     --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
  ```
  
+ ### Hyperparameter Tuning
+ 
+ ```bash
+ python src/main.py \
+     --model_name adafortitran \
+     --system_config_path config/system_config.yaml \
+     --model_config_path config/adafortitran.yaml \
+     --train_set data/train \
+     --val_set data/val \
+     --test_set data/test \
+     --exp_id hyperparameter_tuning \
+     --batch_size 64 \
+     --lr 1e-3 \
+     --max_epoch 50 \
+     --patience 5 \
+     --weight_decay 1e-5 \
+     --gradient_clip_val 0.5 \
+     --use_mixed_precision \
+     --test_every_n 5
+ ```
+ 
+ ## 📈 Monitoring and Logging
+ 
+ ### TensorBoard Integration
+ 
+ Training automatically logs metrics to TensorBoard:
  
  ```bash
+ tensorboard --logdir runs/
  ```
  
+ Available metrics:
+ - Training/validation loss
+ - Learning rate
+ - Test performance across conditions
+ - Error visualizations
+ - Model hyperparameters
+ 
+ ### Log Files
+ 
+ Training logs are saved to:
+ - `logs/training_{exp_id}.log`: Python logging output
+ - `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints
+ 
+ ## 🧪 Testing and Evaluation
+ 
+ ### Automatic Testing
+ 
+ The training pipeline automatically evaluates models on:
+ - **DS (Delay Spread)**: Varying delay-spread conditions
+ - **SNR**: Different signal-to-noise ratios
+ - **MDS (Multi-Doppler)**: Various Doppler-shift scenarios
+ 
+ ### Manual Evaluation
+ 
+ ```python
+ import torch
+ 
+ from src.models import AdaFortiTranEstimator
+ from src.config import load_config
+ 
+ # Load configurations
+ system_config, model_config = load_config(
+     'config/system_config.yaml',
+     'config/adafortitran.yaml'
+ )
+ 
+ # Initialize model
+ model = AdaFortiTranEstimator(system_config, model_config)
+ 
+ # Load checkpoint
+ checkpoint = torch.load('checkpoint.pt')
+ model.load_state_dict(checkpoint['model_state_dict'])
+ 
+ # Evaluate
+ model.eval()
+ # ... evaluation code
+ ```
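Since MSE against the ground-truth channel is the reported metric, a manual evaluation can compute it directly from the `H` layout described earlier. A NumPy sketch with synthetic data (function name, grid shape, and noise level here are illustrative, not taken from the repo):

```python
import numpy as np

def channel_mse(h_true: np.ndarray, h_est: np.ndarray) -> float:
    """Mean squared error between complex channel matrices."""
    return float(np.mean(np.abs(h_true - h_est) ** 2))

# Toy data on the documented 120-subcarrier x 14-symbol grid.
rng = np.random.default_rng(0)
h_true = rng.standard_normal((120, 14)) + 1j * rng.standard_normal((120, 14))
noise = 0.1 * (rng.standard_normal((120, 14)) + 1j * rng.standard_normal((120, 14)))
h_est = h_true + noise

print(channel_mse(h_true, h_est))  # ~0.02 (complex noise variance 2 * 0.1**2)
```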
+ 
+ ## 🔬 Research and Development
+ 
+ ### Adding Custom Callbacks
+ 
+ ```python
+ from src.main.trainer import Callback, TrainingMetrics
+ 
+ class CustomCallback(Callback):
+     def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
+         # Custom logic here
+         print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
+ ```
+ 
+ ### Extending Models
+ 
+ The modular architecture makes it easy to add new model variants:
+ 
+ ```python
+ from src.models.fortitran import BaseFortiTranEstimator
+ 
+ class CustomEstimator(BaseFortiTranEstimator):
+     def __init__(self, system_config, model_config):
+         super().__init__(system_config, model_config, use_channel_adaptation=True)
+         # Add custom components
+ ```
+ 
+ ## 🐛 Troubleshooting
+ 
+ ### Common Issues
+ 
+ **CUDA Out of Memory**:
+ - Reduce the batch size: `--batch_size 32`
+ - Enable mixed precision: `--use_mixed_precision`
+ - Reduce the number of workers: `--num_workers 2`
+ 
+ **Slow Training**:
+ - Increase the number of workers: `--num_workers 8`
+ - Enable pinned memory: `--pin_memory`
+ - Use mixed precision: `--use_mixed_precision`
+ 
+ **Poor Convergence**:
+ - Adjust the learning rate: `--lr 1e-4`
+ - Add gradient clipping: `--gradient_clip_val 1.0`
+ - Increase patience: `--patience 10`
+ 
+ ### Getting Help
+ 
+ 1. Check the logs in `logs/training_{exp_id}.log`
+ 2. Verify that the dataset format matches the requirements
+ 3. Ensure all dependencies are installed correctly
+ 4. Check TensorBoard for the training curves
+ 
+ ## 📚 Citation
+ 
+ If you use this code in your research, please cite:
  
  ```bibtex
  @misc{guler2025adafortitranadaptivetransformermodel,
        title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
        author={Berkay Guler and Hamid Jafarkhani},
        year={2025},
        eprint={2505.09076},
        archivePrefix={arXiv},
        primaryClass={cs.LG},
        url={https://arxiv.org/abs/2505.09076},
  }
  ```
  
+ ## 📄 License
  
  This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+ 
+ Copyright (c) 2025 Berkay Guler, University of California, Irvine
README_original.md DELETED
@@ -1,438 +0,0 @@
- (entire file deleted; its 438 lines were identical to the updated README.md shown above)