# AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
[![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)
Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076), accepted at ICC 2025, Montreal, Canada.
## 📖 Overview
AdaFortiTran is an adaptive transformer-based model for OFDM channel estimation that adjusts to varying channel conditions (SNR, delay spread, Doppler shift). It combines a transformer architecture with channel-aware adaptation mechanisms to stay robust across diverse wireless environments.
### Key Features
- **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
- **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
- **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
- **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
- **🚀 Production Ready**: Training pipeline with mixed precision, gradient clipping, early stopping, checkpointing, and TensorBoard logging
## 🏗️ Architecture
The project implements three model variants:
1. **Linear Estimator**: Simple learned linear transformation baseline
2. **FortiTran**: Fixed transformer-based channel estimator
3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
### Model Comparison
| Model | Channel Adaptation | Complexity | Performance |
|-------|-------------------|------------|-------------|
| Linear | ❌ | Low | Baseline |
| FortiTran | ❌ | Medium | Good |
| AdaFortiTran | ✅ | High | **Best** |
## 🚀 Quick Start
### Installation
1. **Clone the repository**:
```bash
git clone https://github.com/your-username/AdaFortiTran.git
cd AdaFortiTran
```
2. **Install dependencies**:
```bash
pip install -r requirements.txt
```
3. **Verify installation**:
```bash
python -c "import torch; print(f'PyTorch {torch.__version__}')"
```
### Basic Training
Train an AdaFortiTran model with default settings:
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id my_experiment
```
### Advanced Training
Enable the optional training features, such as mixed precision, gradient clipping, weight decay, and periodic checkpointing:
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id advanced_experiment \
    --batch_size 128 \
    --lr 5e-4 \
    --max_epoch 100 \
    --patience 10 \
    --weight_decay 1e-4 \
    --gradient_clip_val 1.0 \
    --use_mixed_precision \
    --save_every_n_epochs 5 \
    --num_workers 8 \
    --test_every_n 5
```
## 📁 Project Structure
```
AdaFortiTran/
├── config/                     # Configuration files
│   ├── system_config.yaml      # OFDM system parameters
│   ├── adafortitran.yaml       # AdaFortiTran model config
│   ├── fortitran.yaml          # FortiTran model config
│   └── linear.yaml             # Linear model config
├── data/                       # Dataset directory
│   ├── train/                  # Training data
│   ├── val/                    # Validation data
│   └── test/                   # Test data (DS, MDS, SNR sets)
├── src/                        # Source code
│   ├── main/                   # Training pipeline
│   │   ├── trainer.py          # Enhanced ModelTrainer
│   │   └── parser.py           # Command-line argument parser
│   ├── models/                 # Model implementations
│   │   ├── adafortitran.py     # AdaFortiTran model
│   │   ├── fortitran.py        # FortiTran model
│   │   ├── linear.py           # Linear model
│   │   └── blocks/             # Model building blocks
│   ├── data/                   # Data loading
│   │   └── dataset.py          # Dataset and DataLoader classes
│   ├── config/                 # Configuration management
│   │   ├── config_loader.py    # YAML configuration loader
│   │   └── schemas.py          # Pydantic validation schemas
│   └── utils.py                # Utility functions
├── requirements.txt            # Python dependencies
└── README.md                   # This file
```
## ⚙️ Configuration
### System Configuration (`config/system_config.yaml`)
Defines OFDM system parameters:
```yaml
ofdm:
  num_scs: 120      # Number of subcarriers
  num_symbols: 14   # Number of OFDM symbols
pilot:
  num_scs: 12       # Number of pilot subcarriers
  num_symbols: 2    # Number of pilot symbols
```
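For a sense of scale, the sketch below works out the pilot density these values imply. It hard-codes the YAML values in a plain dictionary so it runs without a YAML parser, and it assumes the pilots occupy `num_scs x num_symbols` distinct resource elements. The helper name `pilot_density` is illustrative and not part of the repository.

```python
# Values copied from config/system_config.yaml above.
system_config = {
    "ofdm": {"num_scs": 120, "num_symbols": 14},
    "pilot": {"num_scs": 12, "num_symbols": 2},
}


def pilot_density(cfg: dict) -> float:
    """Fraction of resource elements that carry pilots (hypothetical helper)."""
    grid_size = cfg["ofdm"]["num_scs"] * cfg["ofdm"]["num_symbols"]
    num_pilots = cfg["pilot"]["num_scs"] * cfg["pilot"]["num_symbols"]
    return num_pilots / grid_size


grid = system_config["ofdm"]["num_scs"] * system_config["ofdm"]["num_symbols"]
pilots = system_config["pilot"]["num_scs"] * system_config["pilot"]["num_symbols"]
print(f"{pilots} pilots on a {grid}-element grid ({pilot_density(system_config):.2%})")
```

Under that assumption, the estimator has to reconstruct 1680 channel coefficients from 24 pilot observations.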
### Model Configuration (`config/adafortitran.yaml`)
Defines model architecture parameters:
```yaml
model_type: 'adafortitran'
patch_size: [3, 2] # Patch dimensions
num_layers: 6 # Transformer layers
model_dim: 128 # Model dimension
num_head: 4 # Attention heads
activation: 'gelu' # Activation function
dropout: 0.1 # Dropout rate
max_seq_len: 512 # Maximum sequence length
pos_encoding_type: 'learnable' # Positional encoding
channel_adaptivity_hidden_sizes: [7, 42, 560] # Adaptation layers
adaptive_token_length: 6 # Adaptive token length
```
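A quick consistency check on these values can be sketched as follows. It assumes the model tiles the full 120 x 14 grid from `system_config.yaml` with non-overlapping patches and splits `model_dim` evenly across attention heads, as standard multi-head attention does. This README confirms neither assumption, and `check_model_config` is a hypothetical helper.

```python
def check_model_config(grid_shape, patch_size, model_dim, num_head, max_seq_len):
    """Return (num_patches, head_dim) or raise ValueError (hypothetical helper)."""
    n_scs, n_sym = grid_shape
    p_scs, p_sym = patch_size
    if n_scs % p_scs or n_sym % p_sym:
        raise ValueError("patch_size must divide the grid in both dimensions")
    if model_dim % num_head:
        raise ValueError("model_dim must be divisible by num_head")
    num_patches = (n_scs // p_scs) * (n_sym // p_sym)
    if num_patches > max_seq_len:
        raise ValueError("patch sequence exceeds max_seq_len")
    return num_patches, model_dim // num_head


num_patches, head_dim = check_model_config(
    grid_shape=(120, 14), patch_size=(3, 2),
    model_dim=128, num_head=4, max_seq_len=512,
)
print(num_patches, head_dim)  # 280 32
```

Any extra tokens the model prepends or appends (for example, adaptive tokens) would add to the 280 patch tokens, so the headroom below `max_seq_len` matters.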
## 🎯 Training Features
### Advanced Training Options
| Feature | Description | Default |
|---------|-------------|---------|
| `--use_mixed_precision` | Enable mixed precision training | False |
| `--gradient_clip_val` | Gradient clipping value | None |
| `--weight_decay` | Weight decay for optimizer | 0.0 |
| `--save_checkpoints` | Enable model checkpointing | True |
| `--save_best_only` | Save only best model | True |
| `--resume_from_checkpoint` | Resume from checkpoint | None |
| `--num_workers` | Data loading workers | 4 |
| `--pin_memory` | Pin memory for GPU | True |
### Callback System
The training pipeline includes an extensible callback system:
- **TensorBoard Logging**: Automatic metric tracking and visualization
- **Checkpoint Management**: Flexible checkpoint saving strategies
- **Custom Callbacks**: Easy to add new logging or monitoring systems
### Performance Optimizations
- **Mixed Precision Training**: Faster training on modern GPUs
- **Optimized Data Loading**: Configurable workers and memory pinning
- **Gradient Clipping**: Stable training with configurable clipping
- **Early Stopping**: Automatic training termination on plateau
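The early-stopping behaviour controlled by `--patience` can be illustrated with a small sketch. This is a generic patience counter, not the repository's implementation; the class name `EarlyStopping` and the `min_delta` parameter are assumptions made for illustration.

```python
class EarlyStopping:
    """Stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=3)
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.65, 0.60]  # made-up values
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        print(f"Stopping after epoch {epoch}; best val loss {stopper.best_loss:.2f}")
        break
```

With these made-up losses the loop stops after epoch 6, because epochs 4 to 6 fail to beat the best loss of 0.65. A larger `--patience` tolerates longer plateaus at the cost of extra epochs.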
## 📊 Dataset Format
### Expected File Structure
```
data/
├── train/
│   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   └── ...
├── val/
│   └── ...
└── test/
    ├── DS_test_set/            # Delay Spread tests
    │   ├── DS_50/
    │   ├── DS_100/
    │   └── ...
    ├── SNR_test_set/           # SNR tests
    │   ├── SNR_10/
    │   ├── SNR_20/
    │   └── ...
    └── MDS_test_set/           # Multi-Doppler tests
        ├── DOP_200/
        ├── DOP_400/
        └── ...
```
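One way to enumerate the condition folders in a test set is sketched below. It assumes folder names follow the `DS_<int>`, `SNR_<int>`, and `DOP_<int>` pattern shown in the tree; `list_conditions` is a hypothetical helper, not a function from this repository. Sorting numerically avoids the lexicographic order in which `DS_100` would come before `DS_50`.

```python
import re
import tempfile
from pathlib import Path
from typing import List, Tuple

# Assumed folder naming: a condition prefix, an underscore, and an integer.
_CONDITION_DIR = re.compile(r"^(DS|SNR|DOP)_(-?\d+)$")


def list_conditions(test_set_dir: Path) -> List[Tuple[str, int]]:
    """Return (kind, value) for each condition folder, sorted by value."""
    found = []
    for child in test_set_dir.iterdir():
        match = _CONDITION_DIR.match(child.name)
        if child.is_dir() and match:
            found.append((match.group(1), int(match.group(2))))
    return sorted(found, key=lambda item: item[1])


# Demonstration on a throwaway directory that mimics data/test/DS_test_set/.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "DS_test_set"
    for name in ("DS_100", "DS_50"):
        (root / name).mkdir(parents=True)
    conditions = list_conditions(root)

print(conditions)  # [('DS', 50), ('DS', 100)]
```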
### File Naming Convention
Files must follow the pattern:
```
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
```
Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`
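A minimal sketch of parsing this convention with a regular expression follows. `parse_filename` is a hypothetical helper, not the repository's loader. It assumes every numeric field is an integer and that a negative SNR, if used at all, is written with an extra minus sign after the `SNR-` prefix.

```python
import re
from typing import Dict, Union

# One named group per field of the documented pattern.
FILENAME_PATTERN = re.compile(
    r"^(?P<file_number>\d+)"
    r"_SNR-(?P<snr>-?\d+)"
    r"_DS-(?P<delay_spread>\d+)"
    r"_DOP-(?P<doppler>\d+)"
    r"_N-(?P<pilot_freq>\d+)"
    r"_(?P<channel_type>[A-Za-z0-9-]+)\.mat$"
)


def parse_filename(name: str) -> Dict[str, Union[int, str]]:
    """Split a dataset file name into its fields (hypothetical helper)."""
    match = FILENAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unexpected file name: {name!r}")
    return {
        key: value if key == "channel_type" else int(value)
        for key, value in match.groupdict().items()
    }


info = parse_filename("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
print(info["snr"], info["delay_spread"], info["doppler"], info["channel_type"])
# 20 50 500 TDL-A
```

Rejecting names that do not match, rather than skipping them silently, makes a misnamed file show up as an error at load time.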
### Data Format
Each `.mat` file must contain a variable `H` with shape `[subcarriers, symbols, 3]`:
- `H[:, :, 0]`: Ground truth channel (complex values)
- `H[:, :, 1]`: LS channel estimate with zeros for non-pilot positions
- `H[:, :, 2]`: Reserved for future use
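The layout above can be unpacked as in the following sketch, which runs on a synthetic array instead of a real file. The pilot positions, the NumPy dependency, and the helper `split_channel_tensor` are assumptions for illustration. In practice `H` would typically be read with something like `scipy.io.loadmat(path)["H"]`.

```python
import numpy as np


def split_channel_tensor(H: np.ndarray):
    """Split H of shape [subcarriers, symbols, 3] into target, LS input, pilot mask."""
    if H.ndim != 3 or H.shape[2] != 3:
        raise ValueError(f"Expected shape [subcarriers, symbols, 3], got {H.shape}")
    h_true = H[:, :, 0]       # ground-truth channel
    h_ls = H[:, :, 1]         # LS estimate, zero away from pilots
    pilot_mask = h_ls != 0    # assumes no pilot estimate is exactly zero
    return h_true, h_ls, pilot_mask


# Synthetic stand-in for one sample on a 120 x 14 grid.
rng = np.random.default_rng(0)
H = np.zeros((120, 14, 3), dtype=np.complex64)
H[:, :, 0] = rng.standard_normal((120, 14)) + 1j * rng.standard_normal((120, 14))

# Hypothetical pilot grid: every 10th subcarrier on two OFDM symbols.
pilot_idx = np.ix_(np.arange(0, 120, 10), np.array([2, 11]))
ls_plane = np.zeros((120, 14), dtype=np.complex64)
ls_plane[pilot_idx] = H[:, :, 0][pilot_idx]  # noiseless LS values for simplicity
H[:, :, 1] = ls_plane

h_true, h_ls, pilot_mask = split_channel_tensor(H)
print(h_true.shape, int(pilot_mask.sum()))  # (120, 14) 24
```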
## 🔧 Usage Examples
### Training Different Models
**Linear Estimator**:
```bash
python src/main.py \
    --model_name linear \
    --system_config_path config/system_config.yaml \
    --model_config_path config/linear.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id linear_baseline
```
**FortiTran**:
```bash
python src/main.py \
    --model_name fortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/fortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id fortitran_experiment
```
**AdaFortiTran**:
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id adafortitran_experiment
```
### Resume Training
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id resumed_experiment \
    --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
```
### Hyperparameter Tuning
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id hyperparameter_tuning \
    --batch_size 64 \
    --lr 1e-3 \
    --max_epoch 50 \
    --patience 5 \
    --weight_decay 1e-5 \
    --gradient_clip_val 0.5 \
    --use_mixed_precision \
    --test_every_n 5
```
## 📈 Monitoring and Logging
### TensorBoard Integration
Training automatically logs metrics to TensorBoard:
```bash
tensorboard --logdir runs/
```
Available metrics:
- Training/validation loss
- Learning rate
- Test performance across conditions
- Error visualizations
- Model hyperparameters
### Log Files
Training logs are saved to:
- `logs/training_{exp_id}.log`: Python logging output
- `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints
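To locate outputs programmatically, the two patterns above can be expanded as in this sketch. `experiment_paths` is a hypothetical helper, and it assumes both directories are relative to the directory from which training was launched.

```python
from pathlib import Path
from typing import Dict


def experiment_paths(model_name: str, exp_id: str, root: Path = Path(".")) -> Dict[str, Path]:
    """Build the output locations described above (hypothetical helper)."""
    return {
        "log_file": root / "logs" / f"training_{exp_id}.log",
        "run_dir": root / "runs" / f"{model_name}_{exp_id}",
    }


paths = experiment_paths("adafortitran", "my_experiment")
print(paths["log_file"].as_posix())  # logs/training_my_experiment.log
print(paths["run_dir"].as_posix())   # runs/adafortitran_my_experiment
```

Because the run directory already starts with the model name, an `--exp_id` that repeats the model name produces a doubled prefix in the directory name.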
## 🧪 Testing and Evaluation
### Automatic Testing
The training pipeline automatically evaluates models on:
- **DS (Delay Spread)**: Varying delay spread conditions
- **SNR**: Different signal-to-noise ratios
- **MDS (Multi-Doppler)**: Various Doppler shift scenarios
### Manual Evaluation
```python
import torch

from src.models import AdaFortiTranEstimator
from src.config import load_config

# Load configurations
system_config, model_config = load_config(
    'config/system_config.yaml',
    'config/adafortitran.yaml'
)

# Initialize model
model = AdaFortiTranEstimator(system_config, model_config)

# Load checkpoint (map to CPU so this also works on machines without a GPU)
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate
model.eval()
# ... evaluation code
```
## 🔬 Research and Development
### Adding Custom Callbacks
```python
from src.main.trainer import Callback, TrainingMetrics


class CustomCallback(Callback):
    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
        # Custom logic here
        print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
```
### Extending Models
The modular architecture makes it easy to add new model variants:
```python
from src.models.fortitran import BaseFortiTranEstimator


class CustomEstimator(BaseFortiTranEstimator):
    def __init__(self, system_config, model_config):
        super().__init__(system_config, model_config, use_channel_adaptation=True)
        # Add custom components
```
## 🐛 Troubleshooting
### Common Issues
**CUDA Out of Memory**:
- Reduce batch size: `--batch_size 32`
- Enable mixed precision: `--use_mixed_precision`
- Reduce number of workers: `--num_workers 2`
**Slow Training**:
- Increase number of workers: `--num_workers 8`
- Enable pin memory: `--pin_memory`
- Use mixed precision: `--use_mixed_precision`
**Poor Convergence**:
- Adjust learning rate: `--lr 1e-4`
- Add gradient clipping: `--gradient_clip_val 1.0`
- Increase patience: `--patience 10`
### Getting Help
1. Check the logs in `logs/training_{exp_id}.log`
2. Verify dataset format matches requirements
3. Ensure all dependencies are installed correctly
4. Check TensorBoard for training curves
## 📚 Citation
If you use this code in your research, please cite:
```bibtex
@misc{guler2025adafortitranadaptivetransformermodel,
  title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
  author={Berkay Guler and Hamid Jafarkhani},
  year={2025},
  eprint={2505.09076},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2505.09076},
}
```
## 📄 License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
Copyright (c) 2025 [Berkay Guler/University of California, Irvine]