# AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation

[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
[![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)

Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076), accepted at ICC 2025, Montreal, Canada.

## 📖 Overview

AdaFortiTran is an adaptive transformer-based model for OFDM channel estimation that adapts to varying channel conditions (SNR, delay spread, Doppler shift). It combines a transformer architecture with channel-aware adaptation mechanisms to stay robust across diverse wireless environments.

### Key Features

- **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
- **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
- **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
- **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
- **🚀 Production Ready**: Comprehensive training pipeline with advanced features

## 🏗️ Architecture

The project implements three model variants:

1. **Linear Estimator**: Simple learned linear transformation baseline
2. **FortiTran**: Fixed transformer-based channel estimator
3. **AdaFortiTran**: Adaptive transformer with channel condition awareness

### Model Comparison

| Model | Channel Adaptation | Complexity | Performance |
|-------|-------------------|------------|-------------|
| Linear | ❌ | Low | Baseline |
| FortiTran | ❌ | Medium | Good |
| AdaFortiTran | ✅ | High | **Best** |

## 🚀 Quick Start

### Installation

1. **Clone the repository**:
   ```bash
   git clone https://github.com/your-username/AdaFortiTran.git
   cd AdaFortiTran
   ```

2. **Install dependencies**:
   ```bash
   pip install -r requirements.txt
   ```

3. **Verify installation**:
   ```bash
   python -c "import torch; print(f'PyTorch {torch.__version__}')"
   ```

### Basic Training

Train an AdaFortiTran model with default settings:

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id my_experiment
```

### Advanced Training

Enable the optional training features (mixed precision, gradient clipping, weight decay, periodic checkpoints and tests):

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id advanced_experiment \
    --batch_size 128 \
    --lr 5e-4 \
    --max_epoch 100 \
    --patience 10 \
    --weight_decay 1e-4 \
    --gradient_clip_val 1.0 \
    --use_mixed_precision \
    --save_every_n_epochs 5 \
    --num_workers 8 \
    --test_every_n 5
```

## 📁 Project Structure

```
AdaFortiTran/
├── config/                     # Configuration files
│   ├── system_config.yaml      # OFDM system parameters
│   ├── adafortitran.yaml       # AdaFortiTran model config
│   ├── fortitran.yaml          # FortiTran model config
│   └── linear.yaml             # Linear model config
├── data/                       # Dataset directory
│   ├── train/                  # Training data
│   ├── val/                    # Validation data
│   └── test/                   # Test data (DS, MDS, SNR sets)
├── src/                        # Source code
│   ├── main/                   # Training pipeline
│   │   ├── trainer.py          # Enhanced ModelTrainer
│   │   └── parser.py           # Command-line argument parser
│   ├── models/                 # Model implementations
│   │   ├── adafortitran.py     # AdaFortiTran model
│   │   ├── fortitran.py        # FortiTran model
│   │   ├── linear.py           # Linear model
│   │   └── blocks/             # Model building blocks
│   ├── data/                   # Data loading
│   │   └── dataset.py          # Dataset and DataLoader classes
│   ├── config/                 # Configuration management
│   │   ├── config_loader.py    # YAML configuration loader
│   │   └── schemas.py          # Pydantic validation schemas
│   └── utils.py                # Utility functions
├── requirements.txt            # Python dependencies
├── README.md                   # This file
```

## ⚙️ Configuration

### System Configuration (`config/system_config.yaml`)

Defines OFDM system parameters:

```yaml
ofdm:
  num_scs: 120       # Number of subcarriers
  num_symbols: 14    # Number of OFDM symbols
pilot:
  num_scs: 12        # Number of pilot subcarriers
  num_symbols: 2     # Number of pilot symbols
```

### Model Configuration (`config/adafortitran.yaml`)

Defines model architecture parameters:

```yaml
model_type: 'adafortitran'
patch_size: [3, 2]                              # Patch dimensions
num_layers: 6                                   # Transformer layers
model_dim: 128                                  # Model dimension
num_head: 4                                     # Attention heads
activation: 'gelu'                              # Activation function
dropout: 0.1                                    # Dropout rate
max_seq_len: 512                                # Maximum sequence length
pos_encoding_type: 'learnable'                  # Positional encoding
channel_adaptivity_hidden_sizes: [7, 42, 560]   # Adaptation layers
adaptive_token_length: 6                        # Adaptive token length
```

## 🎯 Training Features

### Advanced Training Options

| Feature | Description | Default |
|---------|-------------|---------|
| `--use_mixed_precision` | Enable mixed precision training | False |
| `--gradient_clip_val` | Gradient clipping value | None |
| `--weight_decay` | Weight decay for optimizer | 0.0 |
| `--save_checkpoints` | Enable model checkpointing | True |
| `--save_best_only` | Save only best model | True |
| `--resume_from_checkpoint` | Resume from checkpoint | None |
| `--num_workers` | Data loading workers | 4 |
| `--pin_memory` | Pin memory for GPU | True |

### Callback System

The training pipeline includes an extensible callback system:

- **TensorBoard Logging**: Automatic metric tracking and visualization
- **Checkpoint Management**: Flexible checkpoint saving strategies
- **Custom Callbacks**: Easy to add new logging or monitoring systems

### Performance Optimizations

- **Mixed Precision Training**: Faster training on modern GPUs
- **Optimized Data Loading**: Configurable workers and memory pinning
- **Gradient Clipping**: Stable training with configurable clipping
- **Early Stopping**: Automatic training termination on plateau

## 📊 Dataset Format

### Expected File Structure

```
data/
├── train/
│   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   └── ...
├── val/
│   └── ...
└── test/
    ├── DS_test_set/        # Delay Spread tests
    │   ├── DS_50/
    │   ├── DS_100/
    │   └── ...
    ├── SNR_test_set/       # SNR tests
    │   ├── SNR_10/
    │   ├── SNR_20/
    │   └── ...
    └── MDS_test_set/       # Multi-Doppler tests
        ├── DOP_200/
        ├── DOP_400/
        └── ...
```

### File Naming Convention

Files must follow the pattern:

```
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
```

Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`

### Data Format

Each `.mat` file must contain a variable `H` with shape `[subcarriers, symbols, 3]`:

- `H[:, :, 0]`: Ground truth channel (complex values)
- `H[:, :, 1]`: LS channel estimate with zeros for non-pilot positions
- `H[:, :, 2]`: Reserved for future use

## 🔧 Usage Examples

### Training Different Models

**Linear Estimator**:

```bash
python src/main.py \
    --model_name linear \
    --system_config_path config/system_config.yaml \
    --model_config_path config/linear.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id linear_baseline
```

**FortiTran**:

```bash
python src/main.py \
    --model_name fortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/fortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id fortitran_experiment
```

**AdaFortiTran**:

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id adafortitran_experiment
```

### Resume Training

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id resumed_experiment \
    --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
```

### Hyperparameter Tuning

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id hyperparameter_tuning \
    --batch_size 64 \
    --lr 1e-3 \
    --max_epoch 50 \
    --patience 5 \
    --weight_decay 1e-5 \
    --gradient_clip_val 0.5 \
    --use_mixed_precision \
    --test_every_n 5
```

## 📈 Monitoring and Logging

### TensorBoard Integration

Training automatically logs metrics to TensorBoard:

```bash
tensorboard --logdir runs/
```

Available metrics:

- Training/validation loss
- Learning rate
- Test performance across conditions
- Error visualizations
- Model hyperparameters

### Log Files

Training logs are saved to:

- `logs/training_{exp_id}.log`: Python logging output
- `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints

## 🧪 Testing and Evaluation

### Automatic Testing

The training pipeline automatically evaluates models on:

- **DS (Delay Spread)**: Varying delay spread conditions
- **SNR**: Different signal-to-noise ratios
- **MDS (Multi-Doppler)**: Various Doppler shift scenarios

### Manual Evaluation

```python
import torch

from src.models import AdaFortiTranEstimator
from src.config import load_config

# Load configurations
system_config, model_config = load_config(
    'config/system_config.yaml',
    'config/adafortitran.yaml'
)

# Initialize model
model = AdaFortiTranEstimator(system_config, model_config)

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate
model.eval()
# ... evaluation code
```

## 🔬 Research and Development

### Adding Custom Callbacks

```python
from src.main.trainer import Callback, TrainingMetrics

class CustomCallback(Callback):
    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
        # Custom logic here
        print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
```

### Extending Models

The modular architecture makes it easy to add new model variants:

```python
from src.models.fortitran import BaseFortiTranEstimator

class CustomEstimator(BaseFortiTranEstimator):
    def __init__(self, system_config, model_config):
        super().__init__(system_config, model_config, use_channel_adaptation=True)
        # Add custom components
```

## 🐛 Troubleshooting

### Common Issues

**CUDA Out of Memory**:

- Reduce batch size: `--batch_size 32`
- Enable mixed precision: `--use_mixed_precision`
- Reduce number of workers: `--num_workers 2`

**Slow Training**:

- Increase number of workers: `--num_workers 8`
- Enable pin memory: `--pin_memory`
- Use mixed precision: `--use_mixed_precision`

**Poor Convergence**:

- Adjust learning rate: `--lr 1e-4`
- Add gradient clipping: `--gradient_clip_val 1.0`
- Increase patience: `--patience 10`

### Getting Help

1. Check the logs in `logs/training_{exp_id}.log`
2. Verify that the dataset format matches the requirements
3. Ensure all dependencies are installed correctly
4. Check TensorBoard for training curves

## 📚 Citation

If you use this code in your research, please cite:

```bibtex
@misc{guler2025adafortitranadaptivetransformermodel,
      title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
      author={Berkay Guler and Hamid Jafarkhani},
      year={2025},
      eprint={2505.09076},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.09076},
}
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

Copyright (c) 2025 [Berkay Guler/University of California, Irvine]
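
## 🧩 Appendix: Checking Dataset File Names

The file naming convention in the Dataset Format section can be checked before training with a small script. The sketch below is not part of this repository: `FILENAME_RE`, `ChannelFileInfo`, and `parse_channel_filename` are hypothetical names, and it assumes that every numeric field is an integer and that the channel type contains only letters, digits, and hyphens. The repository's own data loader may parse names differently.

```python
import re
from typing import NamedTuple, Optional

# Hypothetical regex mirroring the documented pattern:
# {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
# Assumption: numeric fields are integers; a negative SNR is written with an extra "-".
FILENAME_RE = re.compile(
    r"^(?P<file_number>\d+)"
    r"_SNR-(?P<snr>-?\d+)"
    r"_DS-(?P<delay_spread>\d+)"
    r"_DOP-(?P<doppler>\d+)"
    r"_N-(?P<pilot_freq>\d+)"
    r"_(?P<channel_type>[A-Za-z0-9-]+)\.mat$"
)


class ChannelFileInfo(NamedTuple):
    """Channel conditions encoded in a dataset file name (illustrative type)."""
    file_number: int
    snr: int
    delay_spread: int
    doppler: int
    pilot_freq: int
    channel_type: str


def parse_channel_filename(name: str) -> Optional[ChannelFileInfo]:
    """Return the parsed fields, or None if the name does not follow the pattern."""
    match = FILENAME_RE.match(name)
    if match is None:
        return None
    fields = match.groupdict()
    return ChannelFileInfo(
        file_number=int(fields["file_number"]),
        snr=int(fields["snr"]),
        delay_spread=int(fields["delay_spread"]),
        doppler=int(fields["doppler"]),
        pilot_freq=int(fields["pilot_freq"]),
        channel_type=fields["channel_type"],
    )


if __name__ == "__main__":
    # Example file name taken from the naming convention section.
    print(parse_channel_filename("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat"))
```

Running a check like this over `data/train` and flagging every name for which the function returns `None` is one way to catch misnamed files early, under the assumptions stated above.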