AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
Official implementation of *AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation*, accepted at ICC 2025 (Montreal, Canada).
Overview
AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.
Key Features
- Adaptive Architecture: Dynamically adapts to channel conditions using meta-information
- High Performance: State-of-the-art results on OFDM channel estimation tasks
- Transformer-Based: Leverages attention mechanisms for long-range dependencies
- Robust: Maintains performance across varying SNR, delay spread, and Doppler conditions
- Production Ready: Comprehensive training pipeline with advanced features
Architecture
The project implements three model variants:
- Linear Estimator: Simple learned linear transformation baseline
- FortiTran: Fixed transformer-based channel estimator
- AdaFortiTran: Adaptive transformer with channel condition awareness
Model Comparison
| Model | Channel Adaptation | Complexity | Performance |
|---|---|---|---|
| Linear | ❌ | Low | Baseline |
| FortiTran | ❌ | Medium | Good |
| AdaFortiTran | ✅ | High | Best |
Quick Start
Installation
Clone the repository:
git clone https://github.com/your-username/AdaFortiTran.git
cd AdaFortiTran

Install dependencies:

pip install -r requirements.txt

Verify the installation:
python -c "import torch; print(f'PyTorch {torch.__version__}')"
Basic Training
Train an AdaFortiTran model with default settings:
python src/main.py \
--model_name adafortitran \
--system_config_path config/system_config.yaml \
--model_config_path config/adafortitran.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id my_experiment
Advanced Training
Use all available features for optimal performance:
python src/main.py \
--model_name adafortitran \
--system_config_path config/system_config.yaml \
--model_config_path config/adafortitran.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id advanced_experiment \
--batch_size 128 \
--lr 5e-4 \
--max_epoch 100 \
--patience 10 \
--weight_decay 1e-4 \
--gradient_clip_val 1.0 \
--use_mixed_precision \
--save_every_n_epochs 5 \
--num_workers 8 \
--test_every_n 5
Project Structure
AdaFortiTran/
├── config/                    # Configuration files
│   ├── system_config.yaml     # OFDM system parameters
│   ├── adafortitran.yaml      # AdaFortiTran model config
│   ├── fortitran.yaml         # FortiTran model config
│   └── linear.yaml            # Linear model config
├── data/                      # Dataset directory
│   ├── train/                 # Training data
│   ├── val/                   # Validation data
│   └── test/                  # Test data (DS, MDS, SNR sets)
├── src/                       # Source code
│   ├── main/                  # Training pipeline
│   │   ├── trainer.py         # Enhanced ModelTrainer
│   │   └── parser.py          # Command-line argument parser
│   ├── models/                # Model implementations
│   │   ├── adafortitran.py    # AdaFortiTran model
│   │   ├── fortitran.py       # FortiTran model
│   │   ├── linear.py          # Linear model
│   │   └── blocks/            # Model building blocks
│   ├── data/                  # Data loading
│   │   └── dataset.py         # Dataset and DataLoader classes
│   ├── config/                # Configuration management
│   │   ├── config_loader.py   # YAML configuration loader
│   │   └── schemas.py         # Pydantic validation schemas
│   └── utils.py               # Utility functions
├── requirements.txt           # Python dependencies
└── README.md                  # This file
Configuration
System Configuration (config/system_config.yaml)
Defines OFDM system parameters:
ofdm:
num_scs: 120 # Number of subcarriers
num_symbols: 14 # Number of OFDM symbols
pilot:
num_scs: 12 # Number of pilot subcarriers
num_symbols: 2 # Number of pilot symbols
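As a quick sanity check on these numbers, the pilot overhead implied by this configuration can be computed directly (an illustrative helper, not part of the repository):

```python
# Pilot overhead for the grid defined in system_config.yaml: a
# 120-subcarrier x 14-symbol resource grid with pilots on 12 subcarriers
# of 2 OFDM symbols.
def pilot_overhead(num_scs: int, num_symbols: int,
                   pilot_scs: int, pilot_symbols: int) -> float:
    """Fraction of resource elements occupied by pilots."""
    return (pilot_scs * pilot_symbols) / (num_scs * num_symbols)

overhead = pilot_overhead(120, 14, 12, 2)
print(f"{overhead:.2%}")  # 24 of 1680 resource elements -> 1.43%
```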
Model Configuration (config/adafortitran.yaml)
Defines model architecture parameters:
model_type: 'adafortitran'
patch_size: [3, 2] # Patch dimensions
num_layers: 6 # Transformer layers
model_dim: 128 # Model dimension
num_head: 4 # Attention heads
activation: 'gelu' # Activation function
dropout: 0.1 # Dropout rate
max_seq_len: 512 # Maximum sequence length
pos_encoding_type: 'learnable' # Positional encoding
channel_adaptivity_hidden_sizes: [7, 42, 560] # Adaptation layers
adaptive_token_length: 6 # Adaptive token length
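A useful back-of-the-envelope check on these values (my own, not repository code; the repository validates configs with the pydantic schemas in src/config/schemas.py): with patch_size [3, 2], the 120x14 channel grid splits into (120/3) x (14/2) = 280 patch tokens, which must fit within max_seq_len.

```python
# Hypothetical consistency check for the configs shown above.
num_scs, num_symbols = 120, 14   # from system_config.yaml
patch_size = (3, 2)              # from adafortitran.yaml
model_dim, num_head = 128, 4
max_seq_len = 512

assert num_scs % patch_size[0] == 0 and num_symbols % patch_size[1] == 0
tokens = (num_scs // patch_size[0]) * (num_symbols // patch_size[1])
assert tokens <= max_seq_len     # 280 patch tokens fit in the sequence
assert model_dim % num_head == 0 # per-head dimension must be an integer
print(tokens)  # 280
```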
Training Features
Advanced Training Options
| Feature | Description | Default |
|---|---|---|
| `--use_mixed_precision` | Enable mixed precision training | False |
| `--gradient_clip_val` | Gradient clipping value | None |
| `--weight_decay` | Weight decay for optimizer | 0.0 |
| `--save_checkpoints` | Enable model checkpointing | True |
| `--save_best_only` | Save only the best model | True |
| `--resume_from_checkpoint` | Resume from checkpoint | None |
| `--num_workers` | Data loading workers | 4 |
| `--pin_memory` | Pin memory for GPU transfer | True |
Callback System
The training pipeline includes an extensible callback system:
- TensorBoard Logging: Automatic metric tracking and visualization
- Checkpoint Management: Flexible checkpoint saving strategies
- Custom Callbacks: Easy to add new logging or monitoring systems
Performance Optimizations
- Mixed Precision Training: Faster training on modern GPUs
- Optimized Data Loading: Configurable workers and memory pinning
- Gradient Clipping: Stable training with configurable clipping
- Early Stopping: Automatic training termination on plateau
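The early-stopping behavior controlled by `--patience` can be sketched as a simple plateau counter (an illustrative stand-in for the logic in trainer.py, not the exact implementation):

```python
class EarlyStopping:
    """Stop training once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1          # plateau: one more non-improving epoch
        return self.counter >= self.patience

stopper = EarlyStopping(patience=3)
stops = [stopper.step(loss) for loss in [1.0, 0.8, 0.79, 0.81, 0.80, 0.82]]
# stops -> [False, False, False, False, False, True]
```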
Dataset Format
Expected File Structure
data/
├── train/
│   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   └── ...
├── val/
│   └── ...
└── test/
    ├── DS_test_set/           # Delay Spread tests
    │   ├── DS_50/
    │   ├── DS_100/
    │   └── ...
    ├── SNR_test_set/          # SNR tests
    │   ├── SNR_10/
    │   ├── SNR_20/
    │   └── ...
    └── MDS_test_set/          # Multi-Doppler tests
        ├── DOP_200/
        ├── DOP_400/
        └── ...
File Naming Convention
Files must follow the pattern:
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
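This pattern can be parsed with a small regular expression, e.g. to recover the channel conditions encoded in a filename (an illustrative helper; the repository's dataset code may parse differently):

```python
import re

# Matches e.g. "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat"
PATTERN = re.compile(
    r"(?P<num>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<ds>\d+)"
    r"_DOP-(?P<dop>\d+)_N-(?P<n>\d+)_(?P<chan>[\w-]+)\.mat"
)

def parse_filename(name: str) -> dict:
    """Extract the channel conditions encoded in a dataset filename."""
    m = PATTERN.fullmatch(name)
    if m is None:
        raise ValueError(f"unrecognized filename: {name}")
    d = m.groupdict()
    return {**d, "snr": int(d["snr"]), "ds": int(d["ds"]), "dop": int(d["dop"])}

info = parse_filename("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
# info["snr"] == 20, info["ds"] == 50, info["dop"] == 500, info["chan"] == "TDL-A"
```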
Data Format
Each .mat file must contain variable H with shape [subcarriers, symbols, 3]:
- H[:, :, 0]: Ground-truth channel (complex values)
- H[:, :, 1]: LS channel estimate with zeros at non-pilot positions
- H[:, :, 2]: Reserved for future use
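A minimal sketch of splitting one sample into its components (for a real file, `H = scipy.io.loadmat(path)["H"]`; here a synthetic array stands in so the snippet is self-contained):

```python
import numpy as np
# from scipy.io import loadmat  # H = loadmat(path)["H"] for a real .mat file

def split_sample(H: np.ndarray):
    """Split the stored tensor into ground truth and LS estimate."""
    assert H.ndim == 3 and H.shape[2] == 3, "expected [subcarriers, symbols, 3]"
    truth = H[:, :, 0]         # ground-truth channel (complex)
    ls_estimate = H[:, :, 1]   # LS estimate, zeros off the pilot grid
    return truth, ls_estimate  # H[:, :, 2] is reserved

# Synthetic stand-in for a loaded file: 120 subcarriers x 14 symbols.
H = np.zeros((120, 14, 3), dtype=np.complex64)
truth, ls = split_sample(H)
# truth.shape == ls.shape == (120, 14)
```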
Usage Examples
Training Different Models
Linear Estimator:
python src/main.py \
--model_name linear \
--system_config_path config/system_config.yaml \
--model_config_path config/linear.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id linear_baseline
FortiTran:
python src/main.py \
--model_name fortitran \
--system_config_path config/system_config.yaml \
--model_config_path config/fortitran.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id fortitran_experiment
AdaFortiTran:
python src/main.py \
--model_name adafortitran \
--system_config_path config/system_config.yaml \
--model_config_path config/adafortitran.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id adafortitran_experiment
Resume Training
python src/main.py \
--model_name adafortitran \
--system_config_path config/system_config.yaml \
--model_config_path config/adafortitran.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id resumed_experiment \
--resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
Hyperparameter Tuning
python src/main.py \
--model_name adafortitran \
--system_config_path config/system_config.yaml \
--model_config_path config/adafortitran.yaml \
--train_set data/train \
--val_set data/val \
--test_set data/test \
--exp_id hyperparameter_tuning \
--batch_size 64 \
--lr 1e-3 \
--max_epoch 50 \
--patience 5 \
--weight_decay 1e-5 \
--gradient_clip_val 0.5 \
--use_mixed_precision \
--test_every_n 5
Monitoring and Logging
TensorBoard Integration
Training automatically logs metrics to TensorBoard:
tensorboard --logdir runs/
Available metrics:
- Training/validation loss
- Learning rate
- Test performance across conditions
- Error visualizations
- Model hyperparameters
Log Files
Training logs are saved to:
- `logs/training_{exp_id}.log`: Python logging output
- `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints
Testing and Evaluation
Automatic Testing
The training pipeline automatically evaluates models on:
- DS (Delay Spread): Varying delay spread conditions
- SNR: Different signal-to-noise ratios
- MDS (Multi-Doppler): Various Doppler shift scenarios
Manual Evaluation
import torch

from src.models import AdaFortiTranEstimator
from src.config import load_config
# Load configurations
system_config, model_config = load_config(
'config/system_config.yaml',
'config/adafortitran.yaml'
)
# Initialize model
model = AdaFortiTranEstimator(system_config, model_config)
# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
# Evaluate
model.eval()
# ... evaluation code
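The elided evaluation step is typically an MSE/NMSE computation between the estimated and ground-truth channel. A minimal NumPy sketch (my formulation, not necessarily the repository's exact metric):

```python
import numpy as np

def nmse_db(h_est: np.ndarray, h_true: np.ndarray) -> float:
    """Normalized MSE in dB: ||h_est - h_true||^2 / ||h_true||^2."""
    err = np.sum(np.abs(h_est - h_true) ** 2)
    ref = np.sum(np.abs(h_true) ** 2)
    return float(10 * np.log10(err / ref))

h_true = np.ones((120, 14), dtype=np.complex64)
h_est = h_true * 1.1  # uniform 10% amplitude error
print(round(nmse_db(h_est, h_true), 1))  # -20.0
```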
Research and Development
Adding Custom Callbacks
from src.main.trainer import Callback, TrainingMetrics
class CustomCallback(Callback):
def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
# Custom logic here
print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
Extending Models
The modular architecture makes it easy to add new model variants:
from src.models.fortitran import BaseFortiTranEstimator
class CustomEstimator(BaseFortiTranEstimator):
def __init__(self, system_config, model_config):
super().__init__(system_config, model_config, use_channel_adaptation=True)
# Add custom components
Troubleshooting
Common Issues
CUDA Out of Memory:
- Reduce the batch size: `--batch_size 32`
- Enable mixed precision: `--use_mixed_precision`
- Reduce the number of workers: `--num_workers 2`

Slow Training:
- Increase the number of workers: `--num_workers 8`
- Enable pinned memory: `--pin_memory`
- Use mixed precision: `--use_mixed_precision`

Poor Convergence:
- Adjust the learning rate: `--lr 1e-4`
- Add gradient clipping: `--gradient_clip_val 1.0`
- Increase early-stopping patience: `--patience 10`
Getting Help
- Check the logs in `logs/training_{exp_id}.log`
- Verify that the dataset format matches the requirements
- Ensure all dependencies are installed correctly
- Check TensorBoard for training curves
Citation
If you use this code in your research, please cite:
@misc{guler2025adafortitranadaptivetransformermodel,
title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
author={Berkay Guler and Hamid Jafarkhani},
year={2025},
eprint={2505.09076},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.09076},
}
License
This project is licensed under the MIT License - see the LICENSE file for details.
Copyright (c) 2025 Berkay Guler, University of California, Irvine