AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation


Official implementation of "AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation", accepted at ICC 2025, Montreal, Canada.

📖 Overview

AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.

Key Features

  • 🔄 Adaptive Architecture: Dynamically adapts to channel conditions using meta-information
  • ⚡ High Performance: State-of-the-art results on OFDM channel estimation tasks
  • 🧠 Transformer-Based: Leverages attention mechanisms for long-range dependencies
  • 🎯 Robust: Maintains performance across varying SNR, delay spread, and Doppler conditions
  • 🚀 Production Ready: Comprehensive training pipeline with advanced features

πŸ—οΈ Architecture

The project implements three model variants:

  1. Linear Estimator: Simple learned linear transformation baseline
  2. FortiTran: Fixed transformer-based channel estimator
  3. AdaFortiTran: Adaptive transformer with channel condition awareness

Model Comparison

| Model         | Channel Adaptation | Complexity | Performance |
|---------------|--------------------|------------|-------------|
| Linear        | ❌                 | Low        | Baseline    |
| FortiTran     | ❌                 | Medium     | Good        |
| AdaFortiTran  | ✅                 | High       | Best        |

🚀 Quick Start

Installation

  1. Clone the repository:

    git clone https://github.com/your-username/AdaFortiTran.git
    cd AdaFortiTran
    
  2. Install dependencies:

    pip install -r requirements.txt
    
  3. Verify installation:

    python -c "import torch; print(f'PyTorch {torch.__version__}')"
    

Basic Training

Train an AdaFortiTran model with default settings:

python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id my_experiment

Advanced Training

Use all available features for optimal performance:

python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id advanced_experiment \
    --batch_size 128 \
    --lr 5e-4 \
    --max_epoch 100 \
    --patience 10 \
    --weight_decay 1e-4 \
    --gradient_clip_val 1.0 \
    --use_mixed_precision \
    --save_every_n_epochs 5 \
    --num_workers 8 \
    --test_every_n 5

πŸ“ Project Structure

AdaFortiTran/
├── config/                    # Configuration files
│   ├── system_config.yaml     # OFDM system parameters
│   ├── adafortitran.yaml      # AdaFortiTran model config
│   ├── fortitran.yaml         # FortiTran model config
│   └── linear.yaml            # Linear model config
├── data/                      # Dataset directory
│   ├── train/                 # Training data
│   ├── val/                   # Validation data
│   └── test/                  # Test data (DS, MDS, SNR sets)
├── src/                       # Source code
│   ├── main/                  # Training pipeline
│   │   ├── trainer.py         # Enhanced ModelTrainer
│   │   └── parser.py          # Command-line argument parser
│   ├── models/                # Model implementations
│   │   ├── adafortitran.py    # AdaFortiTran model
│   │   ├── fortitran.py       # FortiTran model
│   │   ├── linear.py          # Linear model
│   │   └── blocks/            # Model building blocks
│   ├── data/                  # Data loading
│   │   └── dataset.py         # Dataset and DataLoader classes
│   ├── config/                # Configuration management
│   │   ├── config_loader.py   # YAML configuration loader
│   │   └── schemas.py         # Pydantic validation schemas
│   └── utils.py               # Utility functions
├── requirements.txt           # Python dependencies
└── README.md                  # This file

βš™οΈ Configuration

System Configuration (config/system_config.yaml)

Defines OFDM system parameters:

ofdm:
  num_scs: 120      # Number of subcarriers
  num_symbols: 14   # Number of OFDM symbols

pilot:
  num_scs: 12       # Number of pilot subcarriers
  num_symbols: 2    # Number of pilot symbols

Model Configuration (config/adafortitran.yaml)

Defines model architecture parameters:

model_type: 'adafortitran'
patch_size: [3, 2]                    # Patch dimensions
num_layers: 6                         # Transformer layers
model_dim: 128                        # Model dimension
num_head: 4                           # Attention heads
activation: 'gelu'                    # Activation function
dropout: 0.1                          # Dropout rate
max_seq_len: 512                      # Maximum sequence length
pos_encoding_type: 'learnable'        # Positional encoding
channel_adaptivity_hidden_sizes: [7, 42, 560]  # Adaptation layers
adaptive_token_length: 6              # Adaptive token length
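As a quick sanity check on the settings above, the patching arithmetic can be worked out directly. This sketch assumes patch_size tiles the (subcarriers, symbols) grid as (3, 2), which divides the 120 × 14 channel evenly; the exact tiling order inside the repo may differ.

```python
# Sanity-check sketch (not part of the repo): how patch_size translates
# into a transformer token count, assuming patches tile (subcarriers,
# symbols) exactly with no padding.
num_scs, num_symbols = 120, 14   # from config/system_config.yaml
patch_h, patch_w = 3, 2          # patch_size from config/adafortitran.yaml

assert num_scs % patch_h == 0 and num_symbols % patch_w == 0
num_tokens = (num_scs // patch_h) * (num_symbols // patch_w)
print(num_tokens)  # 40 * 7 = 280, comfortably under max_seq_len = 512
```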

🎯 Training Features

Advanced Training Options

| Flag                     | Description                      | Default |
|--------------------------|----------------------------------|---------|
| --use_mixed_precision    | Enable mixed precision training  | False   |
| --gradient_clip_val      | Gradient clipping value          | None    |
| --weight_decay           | Weight decay for optimizer       | 0.0     |
| --save_checkpoints       | Enable model checkpointing       | True    |
| --save_best_only         | Save only the best model         | True    |
| --resume_from_checkpoint | Resume from a checkpoint         | None    |
| --num_workers            | Data loading workers             | 4       |
| --pin_memory             | Pin memory for GPU transfers     | True    |

Callback System

The training pipeline includes an extensible callback system:

  • TensorBoard Logging: Automatic metric tracking and visualization
  • Checkpoint Management: Flexible checkpoint saving strategies
  • Custom Callbacks: Easy to add new logging or monitoring systems

Performance Optimizations

  • Mixed Precision Training: Faster training on modern GPUs
  • Optimized Data Loading: Configurable workers and memory pinning
  • Gradient Clipping: Stable training with configurable clipping
  • Early Stopping: Automatic training termination on plateau
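For readers who want to see how mixed precision and gradient clipping interact, here is a minimal, generic PyTorch training step combining the two. This is a hedged sketch, not the repo's ModelTrainer; the tiny model, data, and hyperparameters are placeholders chosen to mirror the CLI flags above.

```python
import torch

# Generic sketch of one training step with mixed precision and gradient
# clipping (corresponding to --use_mixed_precision and
# --gradient_clip_val 1.0). Placeholder model and data for illustration.
model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x, target = torch.randn(4, 8), torch.randn(4, 8)

with torch.autocast("cuda", enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), target)

opt.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(opt)  # unscale grads before clipping their norm
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(opt)
scaler.update()
```

Note the order: gradients must be unscaled before clipping, otherwise the clip threshold is applied to the scaled gradients.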

📊 Dataset Format

Expected File Structure

data/
├── train/
│   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   └── ...
├── val/
│   └── ...
└── test/
    ├── DS_test_set/          # Delay Spread tests
    │   ├── DS_50/
    │   ├── DS_100/
    │   └── ...
    ├── SNR_test_set/         # SNR tests
    │   ├── SNR_10/
    │   ├── SNR_20/
    │   └── ...
    └── MDS_test_set/         # Multi-Doppler tests
        ├── DOP_200/
        ├── DOP_400/
        └── ...

File Naming Convention

Files must follow the pattern:

{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat

Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
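The naming pattern can be parsed mechanically; the helper below is a hypothetical illustration (not a function from the repo) showing one way to extract the metadata fields with a regular expression.

```python
import re

# Hypothetical helper (not in the repo) that parses the naming pattern
# {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
PATTERN = re.compile(
    r"(?P<file_number>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<delay_spread>\d+)"
    r"_DOP-(?P<doppler>\d+)_N-(?P<pilot_freq>\d+)_(?P<channel_type>[\w-]+)\.mat"
)

def parse_filename(name: str) -> dict:
    m = PATTERN.fullmatch(name)
    if m is None:
        raise ValueError(f"unexpected file name: {name}")
    return m.groupdict()

meta = parse_filename("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
print(meta["snr"], meta["doppler"], meta["channel_type"])  # 20 500 TDL-A
```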

Data Format

Each .mat file must contain a variable H with shape [subcarriers, symbols, 3]:

  • H[:, :, 0]: Ground truth channel (complex values)
  • H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
  • H[:, :, 2]: Reserved for future use
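To make the layout concrete, the snippet below builds a synthetic H array with the expected shape and pilot convention (a real file would instead be read with scipy.io.loadmat). The specific pilot grid strides here are illustrative assumptions, chosen only so that the 12 pilot subcarriers × 2 pilot symbols from the system config come out.

```python
import numpy as np

# Synthetic illustration of the expected layout; a real file would be
# loaded with scipy.io.loadmat(path)["H"]. The pilot strides below are
# an assumed example grid, not the repo's actual pilot placement.
num_scs, num_symbols = 120, 14
H = np.zeros((num_scs, num_symbols, 3), dtype=np.complex64)

H[:, :, 0] = np.random.randn(num_scs, num_symbols)  # ground-truth channel
H[::10, ::7, 1] = 1.0 + 0.5j  # LS estimates at 12 x 2 pilot positions

pilot_mask = H[:, :, 1] != 0  # non-pilot entries stay zero by convention
print(H.shape, int(pilot_mask.sum()))  # (120, 14, 3) 24
```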

🔧 Usage Examples

Training Different Models

Linear Estimator:

python src/main.py \
    --model_name linear \
    --system_config_path config/system_config.yaml \
    --model_config_path config/linear.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id linear_baseline

FortiTran:

python src/main.py \
    --model_name fortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/fortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id fortitran_experiment

AdaFortiTran:

python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id adafortitran_experiment

Resume Training

python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id resumed_experiment \
    --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt

Hyperparameter Tuning

python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id hyperparameter_tuning \
    --batch_size 64 \
    --lr 1e-3 \
    --max_epoch 50 \
    --patience 5 \
    --weight_decay 1e-5 \
    --gradient_clip_val 0.5 \
    --use_mixed_precision \
    --test_every_n 5

📈 Monitoring and Logging

TensorBoard Integration

Training automatically logs metrics to TensorBoard:

tensorboard --logdir runs/

Available metrics:

  • Training/validation loss
  • Learning rate
  • Test performance across conditions
  • Error visualizations
  • Model hyperparameters

Log Files

Training logs are saved to:

  • logs/training_{exp_id}.log: Python logging output
  • runs/{model_name}_{exp_id}/: TensorBoard logs and checkpoints

🧪 Testing and Evaluation

Automatic Testing

The training pipeline automatically evaluates models on:

  • DS (Delay Spread): Varying delay spread conditions
  • SNR: Different signal-to-noise ratios
  • MDS (Multi-Doppler): Various Doppler shift scenarios

Manual Evaluation

import torch

from src.models import AdaFortiTranEstimator
from src.config import load_config

# Load configurations
system_config, model_config = load_config(
    'config/system_config.yaml', 
    'config/adafortitran.yaml'
)

# Initialize model
model = AdaFortiTranEstimator(system_config, model_config)

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate
model.eval()
# ... evaluation code

🔬 Research and Development

Adding Custom Callbacks

from src.main.trainer import Callback, TrainingMetrics

class CustomCallback(Callback):
    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
        # Custom logic here
        print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")

Extending Models

The modular architecture makes it easy to add new model variants:

from src.models.fortitran import BaseFortiTranEstimator

class CustomEstimator(BaseFortiTranEstimator):
    def __init__(self, system_config, model_config):
        super().__init__(system_config, model_config, use_channel_adaptation=True)
        # Add custom components

πŸ› Troubleshooting

Common Issues

CUDA Out of Memory:

  • Reduce batch size: --batch_size 32
  • Enable mixed precision: --use_mixed_precision
  • Reduce number of workers: --num_workers 2

Slow Training:

  • Increase number of workers: --num_workers 8
  • Enable pin memory: --pin_memory
  • Use mixed precision: --use_mixed_precision

Poor Convergence:

  • Adjust learning rate: --lr 1e-4
  • Add gradient clipping: --gradient_clip_val 1.0
  • Increase patience: --patience 10

Getting Help

  1. Check the logs in logs/training_{exp_id}.log
  2. Verify dataset format matches requirements
  3. Ensure all dependencies are installed correctly
  4. Check TensorBoard for training curves

📚 Citation

If you use this code in your research, please cite:

@misc{guler2025adafortitranadaptivetransformermodel,
      title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation}, 
      author={Berkay Guler and Hamid Jafarkhani},
      year={2025},
      eprint={2505.09076},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.09076}, 
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

Copyright (c) 2025 Berkay Guler, University of California, Irvine