BerkIGuler commited on
Commit
7e105b2
Β·
1 Parent(s): 301f5ca

Initial commit for Hugging Face

Browse files
Files changed (4) hide show
  1. .gitattributes +5 -0
  2. .huggingfaceignore +34 -0
  3. README.md +28 -394
  4. README_original.md +438 -0
.gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ *.mat filter=lfs diff=lfs merge=lfs -text
2
+ *.pth filter=lfs diff=lfs merge=lfs -text
3
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
4
+ *.pt filter=lfs diff=lfs merge=lfs -text
5
+ *.bin filter=lfs diff=lfs merge=lfs -text
.huggingfaceignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore large data files during upload
2
+ data/train/
3
+ data/val/
4
+ data/test/
5
+
6
+ # Ignore model checkpoints and logs
7
+ *.ckpt
8
+ *.pth
9
+ *.pt
10
+ logs/
11
+ runs/
12
+ checkpoints/
13
+
14
+ # Ignore temporary files
15
+ __pycache__/
16
+ *.pyc
17
+ *.pyo
18
+ *.pyd
19
+ .Python
20
+ *.so
21
+ .DS_Store
22
+ Thumbs.db
23
+
24
+ # Ignore IDE files
25
+ .vscode/
26
+ .idea/
27
+ *.swp
28
+ *.swo
29
+
30
+ # Ignore environment files
31
+ .env
32
+ .venv/
33
+ venv/
34
+ env/
README.md CHANGED
@@ -1,23 +1,35 @@
1
- # AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
2
-
3
- [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
4
- [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
5
- [![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076) accepted at ICC 2025, Montreal, Canada.
8
 
9
- ## πŸ“– Overview
10
 
11
  AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.
12
 
13
- ### Key Features
 
14
  - **πŸ”„ Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
15
  - **⚑ High Performance**: State-of-the-art results on OFDM channel estimation tasks
16
  - **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
17
  - **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
18
  - **πŸš€ Production Ready**: Comprehensive training pipeline with advanced features
19
 
20
- ## πŸ—οΈ Architecture
21
 
22
  The project implements three model variants:
23
 
@@ -25,399 +37,23 @@ The project implements three model variants:
25
  2. **FortiTran**: Fixed transformer-based channel estimator
26
  3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
27
 
28
- ### Model Comparison
29
-
30
- | Model | Channel Adaptation | Complexity | Performance |
31
- |-------|-------------------|------------|-------------|
32
- | Linear | ❌ | Low | Baseline |
33
- | FortiTran | ❌ | Medium | Good |
34
- | AdaFortiTran | βœ… | High | **Best** |
35
-
36
- ## πŸš€ Quick Start
37
 
38
  ### Installation
39
 
40
- 1. **Clone the repository**:
41
- ```bash
42
- git clone https://github.com/your-username/AdaFortiTran.git
43
- cd AdaFortiTran
44
- ```
45
-
46
- 2. **Install dependencies**:
47
- ```bash
48
- pip install -r requirements.txt
49
- ```
50
-
51
- 3. **Verify installation**:
52
- ```bash
53
- python -c "import torch; print(f'PyTorch {torch.__version__}')"
54
- ```
55
-
56
- ### Basic Training
57
-
58
- Train an AdaFortiTran model with default settings:
59
-
60
- ```bash
61
- python src/main.py \
62
- --model_name adafortitran \
63
- --system_config_path config/system_config.yaml \
64
- --model_config_path config/adafortitran.yaml \
65
- --train_set data/train \
66
- --val_set data/val \
67
- --test_set data/test \
68
- --exp_id my_experiment
69
- ```
70
-
71
- ### Advanced Training
72
-
73
- Use all available features for optimal performance:
74
-
75
- ```bash
76
- python src/main.py \
77
- --model_name adafortitran \
78
- --system_config_path config/system_config.yaml \
79
- --model_config_path config/adafortitran.yaml \
80
- --train_set data/train \
81
- --val_set data/val \
82
- --test_set data/test \
83
- --exp_id advanced_experiment \
84
- --batch_size 128 \
85
- --lr 5e-4 \
86
- --max_epoch 100 \
87
- --patience 10 \
88
- --weight_decay 1e-4 \
89
- --gradient_clip_val 1.0 \
90
- --use_mixed_precision \
91
- --save_every_n_epochs 5 \
92
- --num_workers 8 \
93
- --test_every_n 5
94
- ```
95
-
96
- ## πŸ“ Project Structure
97
-
98
- ```
99
- AdaFortiTran/
100
- β”œβ”€β”€ config/ # Configuration files
101
- β”‚ β”œβ”€β”€ system_config.yaml # OFDM system parameters
102
- β”‚ β”œβ”€β”€ adafortitran.yaml # AdaFortiTran model config
103
- β”‚ β”œβ”€β”€ fortitran.yaml # FortiTran model config
104
- β”‚ └── linear.yaml # Linear model config
105
- β”œβ”€β”€ data/ # Dataset directory
106
- β”‚ β”œβ”€β”€ train/ # Training data
107
- β”‚ β”œβ”€β”€ val/ # Validation data
108
- β”‚ └── test/ # Test data (DS, MDS, SNR sets)
109
- β”œβ”€β”€ src/ # Source code
110
- β”‚ β”œβ”€β”€ main/ # Training pipeline
111
- β”‚ β”‚ β”œβ”€β”€ trainer.py # Enhanced ModelTrainer
112
- β”‚ β”‚ └── parser.py # Command-line argument parser
113
- β”‚ β”œβ”€β”€ models/ # Model implementations
114
- β”‚ β”‚ β”œβ”€β”€ adafortitran.py # AdaFortiTran model
115
- β”‚ β”‚ β”œβ”€β”€ fortitran.py # FortiTran model
116
- β”‚ β”‚ β”œβ”€β”€ linear.py # Linear model
117
- β”‚ β”‚ └── blocks/ # Model building blocks
118
- β”‚ β”œβ”€β”€ data/ # Data loading
119
- β”‚ β”‚ └── dataset.py # Dataset and DataLoader classes
120
- β”‚ β”œβ”€β”€ config/ # Configuration management
121
- β”‚ β”‚ β”œβ”€β”€ config_loader.py # YAML configuration loader
122
- β”‚ β”‚ └── schemas.py # Pydantic validation schemas
123
- β”‚ └── utils.py # Utility functions
124
- β”œβ”€β”€ requirements.txt # Python dependencies
125
- β”œβ”€β”€ README.md # This file
126
- ```
127
-
128
- ## βš™οΈ Configuration
129
-
130
- ### System Configuration (`config/system_config.yaml`)
131
-
132
- Defines OFDM system parameters:
133
-
134
- ```yaml
135
- ofdm:
136
- num_scs: 120 # Number of subcarriers
137
- num_symbols: 14 # Number of OFDM symbols
138
-
139
- pilot:
140
- num_scs: 12 # Number of pilot subcarriers
141
- num_symbols: 2 # Number of pilot symbols
142
- ```
143
-
144
- ### Model Configuration (`config/adafortitran.yaml`)
145
-
146
- Defines model architecture parameters:
147
-
148
- ```yaml
149
- model_type: 'adafortitran'
150
- patch_size: [3, 2] # Patch dimensions
151
- num_layers: 6 # Transformer layers
152
- model_dim: 128 # Model dimension
153
- num_head: 4 # Attention heads
154
- activation: 'gelu' # Activation function
155
- dropout: 0.1 # Dropout rate
156
- max_seq_len: 512 # Maximum sequence length
157
- pos_encoding_type: 'learnable' # Positional encoding
158
- channel_adaptivity_hidden_sizes: [7, 42, 560] # Adaptation layers
159
- adaptive_token_length: 6 # Adaptive token length
160
- ```
161
-
162
- ## 🎯 Training Features
163
-
164
- ### Advanced Training Options
165
-
166
- | Feature | Description | Default |
167
- |---------|-------------|---------|
168
- | `--use_mixed_precision` | Enable mixed precision training | False |
169
- | `--gradient_clip_val` | Gradient clipping value | None |
170
- | `--weight_decay` | Weight decay for optimizer | 0.0 |
171
- | `--save_checkpoints` | Enable model checkpointing | True |
172
- | `--save_best_only` | Save only best model | True |
173
- | `--resume_from_checkpoint` | Resume from checkpoint | None |
174
- | `--num_workers` | Data loading workers | 4 |
175
- | `--pin_memory` | Pin memory for GPU | True |
176
-
177
- ### Callback System
178
-
179
- The training pipeline includes an extensible callback system:
180
-
181
- - **TensorBoard Logging**: Automatic metric tracking and visualization
182
- - **Checkpoint Management**: Flexible checkpoint saving strategies
183
- - **Custom Callbacks**: Easy to add new logging or monitoring systems
184
-
185
- ### Performance Optimizations
186
-
187
- - **Mixed Precision Training**: Faster training on modern GPUs
188
- - **Optimized Data Loading**: Configurable workers and memory pinning
189
- - **Gradient Clipping**: Stable training with configurable clipping
190
- - **Early Stopping**: Automatic training termination on plateau
191
-
192
- ## πŸ“Š Dataset Format
193
-
194
- ### Expected File Structure
195
-
196
- ```
197
- data/
198
- β”œβ”€β”€ train/
199
- β”‚ β”œβ”€β”€ 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
200
- β”‚ β”œβ”€β”€ 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
201
- β”‚ └── ...
202
- β”œβ”€β”€ val/
203
- β”‚ └── ...
204
- └── test/
205
- β”œβ”€β”€ DS_test_set/ # Delay Spread tests
206
- β”‚ β”œβ”€β”€ DS_50/
207
- β”‚ β”œβ”€β”€ DS_100/
208
- β”‚ └── ...
209
- β”œβ”€β”€ SNR_test_set/ # SNR tests
210
- β”‚ β”œβ”€β”€ SNR_10/
211
- β”‚ β”œβ”€β”€ SNR_20/
212
- β”‚ └── ...
213
- └── MDS_test_set/ # Multi-Doppler tests
214
- β”œβ”€β”€ DOP_200/
215
- β”œβ”€β”€ DOP_400/
216
- └── ...
217
- ```
218
-
219
- ### File Naming Convention
220
-
221
- Files must follow the pattern:
222
- ```
223
- {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
224
- ```
225
-
226
- Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`
227
-
228
- ### Data Format
229
-
230
- Each `.mat` file must contain variable `H` with shape `[subcarriers, symbols, 3]`:
231
- - `H[:, :, 0]`: Ground truth channel (complex values)
232
- - `H[:, :, 1]`: LS channel estimate with zeros for non-pilot positions
233
- - `H[:, :, 2]`: Reserved for future use
234
-
235
- ## πŸ”§ Usage Examples
236
-
237
- ### Training Different Models
238
-
239
- **Linear Estimator**:
240
- ```bash
241
- python src/main.py \
242
- --model_name linear \
243
- --system_config_path config/system_config.yaml \
244
- --model_config_path config/linear.yaml \
245
- --train_set data/train \
246
- --val_set data/val \
247
- --test_set data/test \
248
- --exp_id linear_baseline
249
- ```
250
-
251
- **FortiTran**:
252
- ```bash
253
- python src/main.py \
254
- --model_name fortitran \
255
- --system_config_path config/system_config.yaml \
256
- --model_config_path config/fortitran.yaml \
257
- --train_set data/train \
258
- --val_set data/val \
259
- --test_set data/test \
260
- --exp_id fortitran_experiment
261
- ```
262
-
263
- **AdaFortiTran**:
264
- ```bash
265
- python src/main.py \
266
- --model_name adafortitran \
267
- --system_config_path config/system_config.yaml \
268
- --model_config_path config/adafortitran.yaml \
269
- --train_set data/train \
270
- --val_set data/val \
271
- --test_set data/test \
272
- --exp_id adafortitran_experiment
273
- ```
274
-
275
- ### Resume Training
276
-
277
  ```bash
278
- python src/main.py \
279
- --model_name adafortitran \
280
- --system_config_path config/system_config.yaml \
281
- --model_config_path config/adafortitran.yaml \
282
- --train_set data/train \
283
- --val_set data/val \
284
- --test_set data/test \
285
- --exp_id resumed_experiment \
286
- --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
287
  ```
288
 
289
- ### Hyperparameter Tuning
290
-
291
- ```bash
292
- python src/main.py \
293
- --model_name adafortitran \
294
- --system_config_path config/system_config.yaml \
295
- --model_config_path config/adafortitran.yaml \
296
- --train_set data/train \
297
- --val_set data/val \
298
- --test_set data/test \
299
- --exp_id hyperparameter_tuning \
300
- --batch_size 64 \
301
- --lr 1e-3 \
302
- --max_epoch 50 \
303
- --patience 5 \
304
- --weight_decay 1e-5 \
305
- --gradient_clip_val 0.5 \
306
- --use_mixed_precision \
307
- --test_every_n 5
308
- ```
309
-
310
- ## πŸ“ˆ Monitoring and Logging
311
-
312
- ### TensorBoard Integration
313
-
314
- Training automatically logs metrics to TensorBoard:
315
 
316
  ```bash
317
- tensorboard --logdir runs/
318
  ```
319
 
320
- Available metrics:
321
- - Training/validation loss
322
- - Learning rate
323
- - Test performance across conditions
324
- - Error visualizations
325
- - Model hyperparameters
326
-
327
- ### Log Files
328
-
329
- Training logs are saved to:
330
- - `logs/training_{exp_id}.log`: Python logging output
331
- - `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints
332
-
333
- ## πŸ§ͺ Testing and Evaluation
334
-
335
- ### Automatic Testing
336
-
337
- The training pipeline automatically evaluates models on:
338
- - **DS (Delay Spread)**: Varying delay spread conditions
339
- - **SNR**: Different signal-to-noise ratios
340
- - **MDS (Multi-Doppler)**: Various Doppler shift scenarios
341
-
342
- ### Manual Evaluation
343
 
344
- ```python
345
- from src.models import AdaFortiTranEstimator
346
- from src.config import load_config
347
-
348
- # Load configurations
349
- system_config, model_config = load_config(
350
- 'config/system_config.yaml',
351
- 'config/adafortitran.yaml'
352
- )
353
-
354
- # Initialize model
355
- model = AdaFortiTranEstimator(system_config, model_config)
356
-
357
- # Load checkpoint
358
- checkpoint = torch.load('checkpoint.pt')
359
- model.load_state_dict(checkpoint['model_state_dict'])
360
-
361
- # Evaluate
362
- model.eval()
363
- # ... evaluation code
364
- ```
365
-
366
- ## πŸ”¬ Research and Development
367
-
368
- ### Adding Custom Callbacks
369
-
370
- ```python
371
- from src.main.trainer import Callback, TrainingMetrics
372
-
373
- class CustomCallback(Callback):
374
- def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
375
- # Custom logic here
376
- print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
377
- ```
378
-
379
- ### Extending Models
380
-
381
- The modular architecture makes it easy to add new model variants:
382
-
383
- ```python
384
- from src.models.fortitran import BaseFortiTranEstimator
385
-
386
- class CustomEstimator(BaseFortiTranEstimator):
387
- def __init__(self, system_config, model_config):
388
- super().__init__(system_config, model_config, use_channel_adaptation=True)
389
- # Add custom components
390
- ```
391
-
392
- ## πŸ› Troubleshooting
393
-
394
- ### Common Issues
395
-
396
- **CUDA Out of Memory**:
397
- - Reduce batch size: `--batch_size 32`
398
- - Enable mixed precision: `--use_mixed_precision`
399
- - Reduce number of workers: `--num_workers 2`
400
-
401
- **Slow Training**:
402
- - Increase number of workers: `--num_workers 8`
403
- - Enable pin memory: `--pin_memory`
404
- - Use mixed precision: `--use_mixed_precision`
405
-
406
- **Poor Convergence**:
407
- - Adjust learning rate: `--lr 1e-4`
408
- - Add gradient clipping: `--gradient_clip_val 1.0`
409
- - Increase patience: `--patience 10`
410
-
411
- ### Getting Help
412
-
413
- 1. Check the logs in `logs/training_{exp_id}.log`
414
- 2. Verify dataset format matches requirements
415
- 3. Ensure all dependencies are installed correctly
416
- 4. Check TensorBoard for training curves
417
-
418
- ## πŸ“š Citation
419
-
420
- If you use this code in your research, please cite:
421
 
422
  ```bibtex
423
  @misc{guler2025adafortitranadaptivetransformermodel,
@@ -431,8 +67,6 @@ If you use this code in your research, please cite:
431
  }
432
  ```
433
 
434
- ## πŸ“„ License
435
 
436
  This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
437
-
438
- Copyright (c) 2025 [Berkay Guler/University of California, Irvine]
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - pytorch
6
+ - transformer
7
+ - channel-estimation
8
+ - ofdm
9
+ - wireless
10
+ - adaptive
11
+ license: mit
12
+ datasets:
13
+ - custom
14
+ metrics:
15
+ - mse
16
+ ---
17
 
18
+ # AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
19
 
20
+ ## Model Description
21
 
22
  AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.
23
 
24
+ ## Key Features
25
+
26
  - **πŸ”„ Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
27
  - **⚑ High Performance**: State-of-the-art results on OFDM channel estimation tasks
28
  - **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
29
  - **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
30
  - **πŸš€ Production Ready**: Comprehensive training pipeline with advanced features
31
 
32
+ ## Architecture
33
 
34
  The project implements three model variants:
35
 
 
37
  2. **FortiTran**: Fixed transformer-based channel estimator
38
  3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
39
 
40
+ ## Usage
 
 
 
 
 
 
 
 
41
 
42
  ### Installation
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ```bash
45
+ pip install -r requirements.txt
 
 
 
 
 
 
 
 
46
  ```
47
 
48
+ ### Training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  ```bash
51
+ python src/main.py --model_name adafortitran --system_config_path config/system_config.yaml --model_config_path config/adafortitran.yaml --train_set data/train --val_set data/val --test_set data/test --exp_id my_experiment
52
  ```
53
 
54
+ ## Citation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ If you use this model in your research, please cite:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  ```bibtex
59
  @misc{guler2025adafortitranadaptivetransformermodel,
 
67
  }
68
  ```
69
 
70
+ ## License
71
 
72
  This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
 
 
README_original.md ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
2
+
3
+ [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
4
+ [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
5
+ [![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)
6
+
7
+ Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076) accepted at ICC 2025, Montreal, Canada.
8
+
9
+ ## πŸ“– Overview
10
+
11
+ AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.
12
+
13
+ ### Key Features
14
+ - **πŸ”„ Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
15
+ - **⚑ High Performance**: State-of-the-art results on OFDM channel estimation tasks
16
+ - **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
17
+ - **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
18
+ - **πŸš€ Production Ready**: Comprehensive training pipeline with advanced features
19
+
20
+ ## πŸ—οΈ Architecture
21
+
22
+ The project implements three model variants:
23
+
24
+ 1. **Linear Estimator**: Simple learned linear transformation baseline
25
+ 2. **FortiTran**: Fixed transformer-based channel estimator
26
+ 3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
27
+
28
+ ### Model Comparison
29
+
30
+ | Model | Channel Adaptation | Complexity | Performance |
31
+ |-------|-------------------|------------|-------------|
32
+ | Linear | ❌ | Low | Baseline |
33
+ | FortiTran | ❌ | Medium | Good |
34
+ | AdaFortiTran | βœ… | High | **Best** |
35
+
36
+ ## πŸš€ Quick Start
37
+
38
+ ### Installation
39
+
40
+ 1. **Clone the repository**:
41
+ ```bash
42
+ git clone https://github.com/your-username/AdaFortiTran.git
43
+ cd AdaFortiTran
44
+ ```
45
+
46
+ 2. **Install dependencies**:
47
+ ```bash
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ 3. **Verify installation**:
52
+ ```bash
53
+ python -c "import torch; print(f'PyTorch {torch.__version__}')"
54
+ ```
55
+
56
+ ### Basic Training
57
+
58
+ Train an AdaFortiTran model with default settings:
59
+
60
+ ```bash
61
+ python src/main.py \
62
+ --model_name adafortitran \
63
+ --system_config_path config/system_config.yaml \
64
+ --model_config_path config/adafortitran.yaml \
65
+ --train_set data/train \
66
+ --val_set data/val \
67
+ --test_set data/test \
68
+ --exp_id my_experiment
69
+ ```
70
+
71
+ ### Advanced Training
72
+
73
+ Use all available features for optimal performance:
74
+
75
+ ```bash
76
+ python src/main.py \
77
+ --model_name adafortitran \
78
+ --system_config_path config/system_config.yaml \
79
+ --model_config_path config/adafortitran.yaml \
80
+ --train_set data/train \
81
+ --val_set data/val \
82
+ --test_set data/test \
83
+ --exp_id advanced_experiment \
84
+ --batch_size 128 \
85
+ --lr 5e-4 \
86
+ --max_epoch 100 \
87
+ --patience 10 \
88
+ --weight_decay 1e-4 \
89
+ --gradient_clip_val 1.0 \
90
+ --use_mixed_precision \
91
+ --save_every_n_epochs 5 \
92
+ --num_workers 8 \
93
+ --test_every_n 5
94
+ ```
95
+
96
+ ## πŸ“ Project Structure
97
+
98
+ ```
99
+ AdaFortiTran/
100
+ β”œβ”€β”€ config/ # Configuration files
101
+ β”‚ β”œβ”€β”€ system_config.yaml # OFDM system parameters
102
+ β”‚ β”œβ”€β”€ adafortitran.yaml # AdaFortiTran model config
103
+ β”‚ β”œβ”€β”€ fortitran.yaml # FortiTran model config
104
+ β”‚ └── linear.yaml # Linear model config
105
+ β”œβ”€β”€ data/ # Dataset directory
106
+ β”‚ β”œβ”€β”€ train/ # Training data
107
+ β”‚ β”œβ”€β”€ val/ # Validation data
108
+ β”‚ └── test/ # Test data (DS, MDS, SNR sets)
109
+ β”œβ”€β”€ src/ # Source code
110
+ β”‚ β”œβ”€β”€ main/ # Training pipeline
111
+ β”‚ β”‚ β”œβ”€β”€ trainer.py # Enhanced ModelTrainer
112
+ β”‚ β”‚ └── parser.py # Command-line argument parser
113
+ β”‚ β”œβ”€β”€ models/ # Model implementations
114
+ β”‚ β”‚ β”œβ”€β”€ adafortitran.py # AdaFortiTran model
115
+ β”‚ β”‚ β”œβ”€β”€ fortitran.py # FortiTran model
116
+ β”‚ β”‚ β”œβ”€β”€ linear.py # Linear model
117
+ β”‚ β”‚ └── blocks/ # Model building blocks
118
+ β”‚ β”œβ”€β”€ data/ # Data loading
119
+ β”‚ β”‚ └── dataset.py # Dataset and DataLoader classes
120
+ β”‚ β”œβ”€β”€ config/ # Configuration management
121
+ β”‚ β”‚ β”œβ”€β”€ config_loader.py # YAML configuration loader
122
+ β”‚ β”‚ └── schemas.py # Pydantic validation schemas
123
+ β”‚ └── utils.py # Utility functions
124
+ β”œβ”€β”€ requirements.txt # Python dependencies
125
+ β”œβ”€β”€ README.md # This file
126
+ ```
127
+
128
+ ## βš™οΈ Configuration
129
+
130
+ ### System Configuration (`config/system_config.yaml`)
131
+
132
+ Defines OFDM system parameters:
133
+
134
+ ```yaml
135
+ ofdm:
136
+ num_scs: 120 # Number of subcarriers
137
+ num_symbols: 14 # Number of OFDM symbols
138
+
139
+ pilot:
140
+ num_scs: 12 # Number of pilot subcarriers
141
+ num_symbols: 2 # Number of pilot symbols
142
+ ```
143
+
144
+ ### Model Configuration (`config/adafortitran.yaml`)
145
+
146
+ Defines model architecture parameters:
147
+
148
+ ```yaml
149
+ model_type: 'adafortitran'
150
+ patch_size: [3, 2] # Patch dimensions
151
+ num_layers: 6 # Transformer layers
152
+ model_dim: 128 # Model dimension
153
+ num_head: 4 # Attention heads
154
+ activation: 'gelu' # Activation function
155
+ dropout: 0.1 # Dropout rate
156
+ max_seq_len: 512 # Maximum sequence length
157
+ pos_encoding_type: 'learnable' # Positional encoding
158
+ channel_adaptivity_hidden_sizes: [7, 42, 560] # Adaptation layers
159
+ adaptive_token_length: 6 # Adaptive token length
160
+ ```
161
+
162
+ ## 🎯 Training Features
163
+
164
+ ### Advanced Training Options
165
+
166
+ | Feature | Description | Default |
167
+ |---------|-------------|---------|
168
+ | `--use_mixed_precision` | Enable mixed precision training | False |
169
+ | `--gradient_clip_val` | Gradient clipping value | None |
170
+ | `--weight_decay` | Weight decay for optimizer | 0.0 |
171
+ | `--save_checkpoints` | Enable model checkpointing | True |
172
+ | `--save_best_only` | Save only best model | True |
173
+ | `--resume_from_checkpoint` | Resume from checkpoint | None |
174
+ | `--num_workers` | Data loading workers | 4 |
175
+ | `--pin_memory` | Pin memory for GPU | True |
176
+
177
+ ### Callback System
178
+
179
+ The training pipeline includes an extensible callback system:
180
+
181
+ - **TensorBoard Logging**: Automatic metric tracking and visualization
182
+ - **Checkpoint Management**: Flexible checkpoint saving strategies
183
+ - **Custom Callbacks**: Easy to add new logging or monitoring systems
184
+
185
+ ### Performance Optimizations
186
+
187
+ - **Mixed Precision Training**: Faster training on modern GPUs
188
+ - **Optimized Data Loading**: Configurable workers and memory pinning
189
+ - **Gradient Clipping**: Stable training with configurable clipping
190
+ - **Early Stopping**: Automatic training termination on plateau
191
+
192
+ ## πŸ“Š Dataset Format
193
+
194
+ ### Expected File Structure
195
+
196
+ ```
197
+ data/
198
+ β”œβ”€β”€ train/
199
+ β”‚ β”œβ”€β”€ 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
200
+ β”‚ β”œβ”€β”€ 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
201
+ β”‚ └── ...
202
+ β”œβ”€β”€ val/
203
+ β”‚ └── ...
204
+ └── test/
205
+ β”œβ”€β”€ DS_test_set/ # Delay Spread tests
206
+ β”‚ β”œβ”€β”€ DS_50/
207
+ β”‚ β”œβ”€β”€ DS_100/
208
+ β”‚ └── ...
209
+ β”œβ”€β”€ SNR_test_set/ # SNR tests
210
+ β”‚ β”œβ”€β”€ SNR_10/
211
+ β”‚ β”œβ”€β”€ SNR_20/
212
+ β”‚ └── ...
213
+ └── MDS_test_set/ # Multi-Doppler tests
214
+ β”œβ”€β”€ DOP_200/
215
+ β”œβ”€β”€ DOP_400/
216
+ └── ...
217
+ ```
218
+
219
+ ### File Naming Convention
220
+
221
+ Files must follow the pattern:
222
+ ```
223
+ {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
224
+ ```
225
+
226
+ Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`
227
+
228
+ ### Data Format
229
+
230
+ Each `.mat` file must contain variable `H` with shape `[subcarriers, symbols, 3]`:
231
+ - `H[:, :, 0]`: Ground truth channel (complex values)
232
+ - `H[:, :, 1]`: LS channel estimate with zeros for non-pilot positions
233
+ - `H[:, :, 2]`: Reserved for future use
234
+
235
+ ## πŸ”§ Usage Examples
236
+
237
+ ### Training Different Models
238
+
239
+ **Linear Estimator**:
240
+ ```bash
241
+ python src/main.py \
242
+ --model_name linear \
243
+ --system_config_path config/system_config.yaml \
244
+ --model_config_path config/linear.yaml \
245
+ --train_set data/train \
246
+ --val_set data/val \
247
+ --test_set data/test \
248
+ --exp_id linear_baseline
249
+ ```
250
+
251
+ **FortiTran**:
252
+ ```bash
253
+ python src/main.py \
254
+ --model_name fortitran \
255
+ --system_config_path config/system_config.yaml \
256
+ --model_config_path config/fortitran.yaml \
257
+ --train_set data/train \
258
+ --val_set data/val \
259
+ --test_set data/test \
260
+ --exp_id fortitran_experiment
261
+ ```
262
+
263
+ **AdaFortiTran**:
264
+ ```bash
265
+ python src/main.py \
266
+ --model_name adafortitran \
267
+ --system_config_path config/system_config.yaml \
268
+ --model_config_path config/adafortitran.yaml \
269
+ --train_set data/train \
270
+ --val_set data/val \
271
+ --test_set data/test \
272
+ --exp_id adafortitran_experiment
273
+ ```
274
+
275
+ ### Resume Training
276
+
277
+ ```bash
278
+ python src/main.py \
279
+ --model_name adafortitran \
280
+ --system_config_path config/system_config.yaml \
281
+ --model_config_path config/adafortitran.yaml \
282
+ --train_set data/train \
283
+ --val_set data/val \
284
+ --test_set data/test \
285
+ --exp_id resumed_experiment \
286
+ --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
287
+ ```
288
+
289
+ ### Hyperparameter Tuning
290
+
291
+ ```bash
292
+ python src/main.py \
293
+ --model_name adafortitran \
294
+ --system_config_path config/system_config.yaml \
295
+ --model_config_path config/adafortitran.yaml \
296
+ --train_set data/train \
297
+ --val_set data/val \
298
+ --test_set data/test \
299
+ --exp_id hyperparameter_tuning \
300
+ --batch_size 64 \
301
+ --lr 1e-3 \
302
+ --max_epoch 50 \
303
+ --patience 5 \
304
+ --weight_decay 1e-5 \
305
+ --gradient_clip_val 0.5 \
306
+ --use_mixed_precision \
307
+ --test_every_n 5
308
+ ```
309
+
310
+ ## πŸ“ˆ Monitoring and Logging
311
+
312
+ ### TensorBoard Integration
313
+
314
+ Training automatically logs metrics to TensorBoard:
315
+
316
+ ```bash
317
+ tensorboard --logdir runs/
318
+ ```
319
+
320
+ Available metrics:
321
+ - Training/validation loss
322
+ - Learning rate
323
+ - Test performance across conditions
324
+ - Error visualizations
325
+ - Model hyperparameters
326
+
327
+ ### Log Files
328
+
329
+ Training logs are saved to:
330
+ - `logs/training_{exp_id}.log`: Python logging output
331
+ - `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints
332
+
333
+ ## πŸ§ͺ Testing and Evaluation
334
+
335
+ ### Automatic Testing
336
+
337
+ The training pipeline automatically evaluates models on:
338
+ - **DS (Delay Spread)**: Varying delay spread conditions
339
+ - **SNR**: Different signal-to-noise ratios
340
+ - **MDS (Multi-Doppler)**: Various Doppler shift scenarios
341
+
342
+ ### Manual Evaluation
343
+
344
+ ```python
345
+ from src.models import AdaFortiTranEstimator
346
+ from src.config import load_config
347
+
348
+ # Load configurations
349
+ system_config, model_config = load_config(
350
+ 'config/system_config.yaml',
351
+ 'config/adafortitran.yaml'
352
+ )
353
+
354
+ # Initialize model
355
+ model = AdaFortiTranEstimator(system_config, model_config)
356
+
357
+ # Load checkpoint
358
+ checkpoint = torch.load('checkpoint.pt')
359
+ model.load_state_dict(checkpoint['model_state_dict'])
360
+
361
+ # Evaluate
362
+ model.eval()
363
+ # ... evaluation code
364
+ ```
365
+
366
+ ## πŸ”¬ Research and Development
367
+
368
+ ### Adding Custom Callbacks
369
+
370
+ ```python
371
+ from src.main.trainer import Callback, TrainingMetrics
372
+
373
+ class CustomCallback(Callback):
374
+ def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
375
+ # Custom logic here
376
+ print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
377
+ ```
378
+
379
+ ### Extending Models
380
+
381
+ The modular architecture makes it easy to add new model variants:
382
+
383
+ ```python
384
+ from src.models.fortitran import BaseFortiTranEstimator
385
+
386
+ class CustomEstimator(BaseFortiTranEstimator):
387
+ def __init__(self, system_config, model_config):
388
+ super().__init__(system_config, model_config, use_channel_adaptation=True)
389
+ # Add custom components
390
+ ```
391
+
392
+ ## πŸ› Troubleshooting
393
+
394
+ ### Common Issues
395
+
396
+ **CUDA Out of Memory**:
397
+ - Reduce batch size: `--batch_size 32`
398
+ - Enable mixed precision: `--use_mixed_precision`
399
+ - Reduce number of workers: `--num_workers 2`
400
+
401
+ **Slow Training**:
402
+ - Increase number of workers: `--num_workers 8`
403
+ - Enable pin memory: `--pin_memory`
404
+ - Use mixed precision: `--use_mixed_precision`
405
+
406
+ **Poor Convergence**:
407
+ - Adjust learning rate: `--lr 1e-4`
408
+ - Add gradient clipping: `--gradient_clip_val 1.0`
409
+ - Increase patience: `--patience 10`
410
+
411
+ ### Getting Help
412
+
413
+ 1. Check the logs in `logs/training_{exp_id}.log`
414
+ 2. Verify dataset format matches requirements
415
+ 3. Ensure all dependencies are installed correctly
416
+ 4. Check TensorBoard for training curves
417
+
418
+ ## πŸ“š Citation
419
+
420
+ If you use this code in your research, please cite:
421
+
422
+ ```bibtex
423
+ @misc{guler2025adafortitranadaptivetransformermodel,
424
+ title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
425
+ author={Berkay Guler and Hamid Jafarkhani},
426
+ year={2025},
427
+ eprint={2505.09076},
428
+ archivePrefix={arXiv},
429
+ primaryClass={cs.LG},
430
+ url={https://arxiv.org/abs/2505.09076},
431
+ }
432
+ ```
433
+
434
+ ## πŸ“„ License
435
+
436
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
437
+
438
+ Copyright (c) 2025 [Berkay Guler/University of California, Irvine]