Commit
·
8a42349
1
Parent(s):
23330c3
minor fixes on src/config
Browse files
- README.md +1 -1
- config/adafortitran.yaml +1 -0
- config/fortitran.yaml +2 -2
- scripts/add_gitkeep.py +0 -1
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__init__.py +5 -1
- src/config/config_loader.py +2 -1
- src/config/schemas.py +49 -9
- src/main.py +5 -2
README.md
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Official implementation of
|
| 2 |
|
| 3 |
|
| 4 |
## License
|
|
|
|
| 1 |
+
# Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076) accepted at ICC 2025, Montreal, Canada.
|
| 2 |
|
| 3 |
|
| 4 |
## License
|
config/adafortitran.yaml
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
patch_size: [3, 2]
|
| 2 |
num_layers: 6
|
| 3 |
model_dim: 128
|
|
|
|
| 1 |
+
model_type: 'adafortitran'
|
| 2 |
patch_size: [3, 2]
|
| 3 |
num_layers: 6
|
| 4 |
model_dim: 128
|
config/fortitran.yaml
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
patch_size: [3, 2]
|
| 2 |
num_layers: 6
|
| 3 |
model_dim: 128
|
|
@@ -5,5 +6,4 @@ num_head: 4
|
|
| 5 |
activation: 'gelu'
|
| 6 |
dropout: 0.1
|
| 7 |
max_seq_len: 512
|
| 8 |
-
pos_encoding_type: 'learnable'
|
| 9 |
-
adaptive_token_length: 6
|
|
|
|
| 1 |
+
model_type: 'fortitran'
|
| 2 |
patch_size: [3, 2]
|
| 3 |
num_layers: 6
|
| 4 |
model_dim: 128
|
|
|
|
| 6 |
activation: 'gelu'
|
| 7 |
dropout: 0.1
|
| 8 |
max_seq_len: 512
|
| 9 |
+
pos_encoding_type: 'learnable'
|
|
|
scripts/add_gitkeep.py
CHANGED
|
@@ -44,7 +44,6 @@ def add_gitkeep_to_directories(root_path: str | Path):
|
|
| 44 |
print(f"\nTotal .gitkeep files added: {gitkeep_count}")
|
| 45 |
|
| 46 |
if __name__ == "__main__":
|
| 47 |
-
# Add .gitkeep to all subdirectories in the data folder
|
| 48 |
data_path = Path("data")
|
| 49 |
|
| 50 |
print(f"Adding .gitkeep files to subdirectories in {data_path.absolute()}")
|
|
|
|
| 44 |
print(f"\nTotal .gitkeep files added: {gitkeep_count}")
|
| 45 |
|
| 46 |
if __name__ == "__main__":
|
|
|
|
| 47 |
data_path = Path("data")
|
| 48 |
|
| 49 |
print(f"Adding .gitkeep files to subdirectories in {data_path.absolute()}")
|
src/__pycache__/__init__.cpython-312.pyc
DELETED
|
Binary file (156 Bytes)
|
|
|
src/config/__init__.py
CHANGED
|
@@ -1 +1,5 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module provides a clean interface for loading and validating configuration files."""
|
| 2 |
+
|
| 3 |
+
from .config_loader import load_config
|
| 4 |
+
|
| 5 |
+
__all__ = ["load_config"]
|
src/config/config_loader.py
CHANGED
|
@@ -29,6 +29,8 @@ class ConfigLoader:
|
|
| 29 |
ValueError: If configuration validation fails
|
| 30 |
"""
|
| 31 |
system_config_path = Path(system_config_path)
|
|
|
|
|
|
|
| 32 |
model_config = None
|
| 33 |
if model_config_path is not None:
|
| 34 |
model_config_path = Path(model_config_path)
|
|
@@ -48,7 +50,6 @@ class ConfigLoader:
|
|
| 48 |
except ValidationError as e:
|
| 49 |
raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
|
| 50 |
|
| 51 |
-
# Only load model config if path is provided and file exists
|
| 52 |
if model_config_path is not None and model_config_path.exists():
|
| 53 |
try:
|
| 54 |
with open(model_config_path, 'r') as f:
|
|
|
|
| 29 |
ValueError: If configuration validation fails
|
| 30 |
"""
|
| 31 |
system_config_path = Path(system_config_path)
|
| 32 |
+
|
| 33 |
+
# certain models may not have a model config
|
| 34 |
model_config = None
|
| 35 |
if model_config_path is not None:
|
| 36 |
model_config_path = Path(model_config_path)
|
|
|
|
| 50 |
except ValidationError as e:
|
| 51 |
raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
|
| 52 |
|
|
|
|
| 53 |
if model_config_path is not None and model_config_path.exists():
|
| 54 |
try:
|
| 55 |
with open(model_config_path, 'r') as f:
|
src/config/schemas.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
| 1 |
from pydantic import BaseModel, Field, model_validator
|
| 2 |
-
from typing import Self, Tuple, List, Optional
|
| 3 |
import torch
|
| 4 |
|
| 5 |
|
| 6 |
class OFDMParams(BaseModel):
|
|
|
|
|
|
|
| 7 |
num_scs: int = Field(..., gt=0, description="Number of sub-carriers")
|
| 8 |
num_symbols: int = Field(..., gt=0, description="Number of OFDM symbols")
|
| 9 |
|
| 10 |
|
| 11 |
class PilotParams(BaseModel):
|
|
|
|
|
|
|
| 12 |
num_scs: int = Field(..., gt=0, description="Number of pilots across sub-carriers")
|
| 13 |
num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
|
| 14 |
|
|
@@ -17,7 +21,7 @@ class SystemConfig(BaseModel):
|
|
| 17 |
ofdm: OFDMParams
|
| 18 |
pilot: PilotParams
|
| 19 |
|
| 20 |
-
@model_validator(mode='after')
|
| 21 |
def validate_pilot_constraints(self) -> Self:
|
| 22 |
"""Ensure pilot parameters don't exceed OFDM parameters."""
|
| 23 |
if self.pilot.num_scs > self.ofdm.num_scs:
|
|
@@ -33,25 +37,62 @@ class SystemConfig(BaseModel):
|
|
| 33 |
)
|
| 34 |
return self
|
| 35 |
|
| 36 |
-
model_config = {"extra": "forbid"}
|
| 37 |
|
| 38 |
|
| 39 |
class ModelConfig(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
|
| 41 |
num_layers: int = Field(..., gt=0, description="Number of transformer layers")
|
| 42 |
model_dim: int = Field(..., gt=0, description="Model dimension")
|
| 43 |
num_head: int = Field(..., gt=0, description="Number of attention heads")
|
| 44 |
-
activation:
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
|
| 47 |
-
pos_encoding_type:
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
|
| 50 |
default=None,
|
| 51 |
-
description="Hidden sizes for channel adaptation layers"
|
| 52 |
)
|
| 53 |
device: str = Field(default="cpu", description="Device to use")
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
@model_validator(mode='after')
|
| 56 |
def validate_device(self) -> Self:
|
| 57 |
"""Validate that the specified device is available."""
|
|
@@ -67,7 +108,6 @@ class ModelConfig(BaseModel):
|
|
| 67 |
self.device = 'cpu'
|
| 68 |
return self
|
| 69 |
|
| 70 |
-
# Validate CPU
|
| 71 |
if device_str == 'cpu':
|
| 72 |
return self
|
| 73 |
|
|
|
|
| 1 |
from pydantic import BaseModel, Field, model_validator
|
| 2 |
+
from typing import Self, Tuple, List, Optional, Literal
|
| 3 |
import torch
|
| 4 |
|
| 5 |
|
| 6 |
class OFDMParams(BaseModel):
|
| 7 |
+
# ... means required (i.e. no default value)
|
| 8 |
+
# gt=0 means greater than 0
|
| 9 |
num_scs: int = Field(..., gt=0, description="Number of sub-carriers")
|
| 10 |
num_symbols: int = Field(..., gt=0, description="Number of OFDM symbols")
|
| 11 |
|
| 12 |
|
| 13 |
class PilotParams(BaseModel):
|
| 14 |
+
# ... means required (i.e. no default value)
|
| 15 |
+
# gt=0 means greater than 0
|
| 16 |
num_scs: int = Field(..., gt=0, description="Number of pilots across sub-carriers")
|
| 17 |
num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
|
| 18 |
|
|
|
|
| 21 |
ofdm: OFDMParams
|
| 22 |
pilot: PilotParams
|
| 23 |
|
| 24 |
+
@model_validator(mode='after') # validate after all fields are initialized
|
| 25 |
def validate_pilot_constraints(self) -> Self:
|
| 26 |
"""Ensure pilot parameters don't exceed OFDM parameters."""
|
| 27 |
if self.pilot.num_scs > self.ofdm.num_scs:
|
|
|
|
| 37 |
)
|
| 38 |
return self
|
| 39 |
|
| 40 |
+
model_config = {"extra": "forbid"} # forbid extra fields
|
| 41 |
|
| 42 |
|
| 43 |
class ModelConfig(BaseModel):
|
| 44 |
+
model_type: Literal["fortitran", "adafortitran"] = Field(
|
| 45 |
+
default="fortitran",
|
| 46 |
+
description="Type of model (fortitran or adafortitran)"
|
| 47 |
+
)
|
| 48 |
patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
|
| 49 |
num_layers: int = Field(..., gt=0, description="Number of transformer layers")
|
| 50 |
model_dim: int = Field(..., gt=0, description="Model dimension")
|
| 51 |
num_head: int = Field(..., gt=0, description="Number of attention heads")
|
| 52 |
+
activation: Literal["relu", "gelu"] = Field(
|
| 53 |
+
default="gelu",
|
| 54 |
+
description="Activation function used within the transformer's FFN"
|
| 55 |
+
)
|
| 56 |
+
dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate used within the transformer's FFN")
|
| 57 |
max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
|
| 58 |
+
pos_encoding_type: Literal["learnable", "sinusoidal"] = Field(
|
| 59 |
+
default="learnable",
|
| 60 |
+
description="Positional encoding type"
|
| 61 |
+
)
|
| 62 |
+
adaptive_token_length: Optional[int] = Field(
|
| 63 |
+
default=None,
|
| 64 |
+
gt=0,
|
| 65 |
+
description="Adaptive token length (required for AdaFortiTran)"
|
| 66 |
+
)
|
| 67 |
channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
|
| 68 |
default=None,
|
| 69 |
+
description="Hidden sizes for channel adaptation layers (required for AdaFortiTran)"
|
| 70 |
)
|
| 71 |
device: str = Field(default="cpu", description="Device to use")
|
| 72 |
|
| 73 |
+
@model_validator(mode='after')
|
| 74 |
+
def validate_model_specific_requirements(self) -> Self:
|
| 75 |
+
"""Validate model-specific configuration requirements."""
|
| 76 |
+
if self.model_type == "adafortitran":
|
| 77 |
+
if self.channel_adaptivity_hidden_sizes is None:
|
| 78 |
+
raise ValueError(
|
| 79 |
+
"channel_adaptivity_hidden_sizes is required for AdaFortiTran model"
|
| 80 |
+
)
|
| 81 |
+
if self.adaptive_token_length is None:
|
| 82 |
+
raise ValueError(
|
| 83 |
+
"adaptive_token_length is required for AdaFortiTran model"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
if self.model_type == "fortitran":
|
| 87 |
+
if self.channel_adaptivity_hidden_sizes is not None:
|
| 88 |
+
# Note: channel_adaptivity_hidden_sizes will be ignored for FortiTran
|
| 89 |
+
pass
|
| 90 |
+
if self.adaptive_token_length is not None:
|
| 91 |
+
# Note: adaptive_token_length will be ignored for FortiTran
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
return self
|
| 95 |
+
|
| 96 |
@model_validator(mode='after')
|
| 97 |
def validate_device(self) -> Self:
|
| 98 |
"""Validate that the specified device is available."""
|
|
|
|
| 108 |
self.device = 'cpu'
|
| 109 |
return self
|
| 110 |
|
|
|
|
| 111 |
if device_str == 'cpu':
|
| 112 |
return self
|
| 113 |
|
src/main.py
CHANGED
|
@@ -13,7 +13,7 @@ from pathlib import Path
|
|
| 13 |
|
| 14 |
from src.main.parser import parse_arguments
|
| 15 |
from src.main.trainer import train
|
| 16 |
-
from src.config
|
| 17 |
|
| 18 |
|
| 19 |
def setup_logging(log_level: str) -> None:
|
|
@@ -58,7 +58,10 @@ def main() -> None:
|
|
| 58 |
logger.info("Configuration loaded successfully")
|
| 59 |
logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
|
| 60 |
logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Start training
|
| 64 |
logger.info("Initializing training...")
|
|
|
|
| 13 |
|
| 14 |
from src.main.parser import parse_arguments
|
| 15 |
from src.main.trainer import train
|
| 16 |
+
from src.config import load_config
|
| 17 |
|
| 18 |
|
| 19 |
def setup_logging(log_level: str) -> None:
|
|
|
|
| 58 |
logger.info("Configuration loaded successfully")
|
| 59 |
logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
|
| 60 |
logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
|
| 61 |
+
if model_config is not None:
|
| 62 |
+
logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
|
| 63 |
+
else:
|
| 64 |
+
logger.info("Using Linear model (no model config required)")
|
| 65 |
|
| 66 |
# Start training
|
| 67 |
logger.info("Initializing training...")
|