BerkIGuler committed on
Commit
8a42349
·
1 Parent(s): 23330c3

minor fixes on src/config

Browse files
README.md CHANGED
@@ -1,4 +1,4 @@
1
- # Official implementation of ICC 2025 paper [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076)
2
 
3
 
4
  ## License
 
1
+ # Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076) accepted at ICC 2025, Montreal, Canada.
2
 
3
 
4
  ## License
config/adafortitran.yaml CHANGED
@@ -1,3 +1,4 @@
 
1
  patch_size: [3, 2]
2
  num_layers: 6
3
  model_dim: 128
 
1
+ model_type: 'adafortitran'
2
  patch_size: [3, 2]
3
  num_layers: 6
4
  model_dim: 128
config/fortitran.yaml CHANGED
@@ -1,3 +1,4 @@
 
1
  patch_size: [3, 2]
2
  num_layers: 6
3
  model_dim: 128
@@ -5,5 +6,4 @@ num_head: 4
5
  activation: 'gelu'
6
  dropout: 0.1
7
  max_seq_len: 512
8
- pos_encoding_type: 'learnable'
9
- adaptive_token_length: 6
 
1
+ model_type: 'fortitran'
2
  patch_size: [3, 2]
3
  num_layers: 6
4
  model_dim: 128
 
6
  activation: 'gelu'
7
  dropout: 0.1
8
  max_seq_len: 512
9
+ pos_encoding_type: 'learnable'
 
scripts/add_gitkeep.py CHANGED
@@ -44,7 +44,6 @@ def add_gitkeep_to_directories(root_path: str | Path):
44
  print(f"\nTotal .gitkeep files added: {gitkeep_count}")
45
 
46
  if __name__ == "__main__":
47
- # Add .gitkeep to all subdirectories in the data folder
48
  data_path = Path("data")
49
 
50
  print(f"Adding .gitkeep files to subdirectories in {data_path.absolute()}")
 
44
  print(f"\nTotal .gitkeep files added: {gitkeep_count}")
45
 
46
  if __name__ == "__main__":
 
47
  data_path = Path("data")
48
 
49
  print(f"Adding .gitkeep files to subdirectories in {data_path.absolute()}")
src/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (156 Bytes)
 
src/config/__init__.py CHANGED
@@ -1 +1,5 @@
1
- from src.config.schemas import ModelConfig, SystemConfig
 
 
 
 
 
1
+ """This module provides a clean interface for loading and validating configuration files."""
2
+
3
+ from .config_loader import load_config
4
+
5
+ __all__ = ["load_config"]
src/config/config_loader.py CHANGED
@@ -29,6 +29,8 @@ class ConfigLoader:
29
  ValueError: If configuration validation fails
30
  """
31
  system_config_path = Path(system_config_path)
 
 
32
  model_config = None
33
  if model_config_path is not None:
34
  model_config_path = Path(model_config_path)
@@ -48,7 +50,6 @@ class ConfigLoader:
48
  except ValidationError as e:
49
  raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
50
 
51
- # Only load model config if path is provided and file exists
52
  if model_config_path is not None and model_config_path.exists():
53
  try:
54
  with open(model_config_path, 'r') as f:
 
29
  ValueError: If configuration validation fails
30
  """
31
  system_config_path = Path(system_config_path)
32
+
33
+ # certain models may not have a model config
34
  model_config = None
35
  if model_config_path is not None:
36
  model_config_path = Path(model_config_path)
 
50
  except ValidationError as e:
51
  raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
52
 
 
53
  if model_config_path is not None and model_config_path.exists():
54
  try:
55
  with open(model_config_path, 'r') as f:
src/config/schemas.py CHANGED
@@ -1,14 +1,18 @@
1
  from pydantic import BaseModel, Field, model_validator
2
- from typing import Self, Tuple, List, Optional
3
  import torch
4
 
5
 
6
  class OFDMParams(BaseModel):
 
 
7
  num_scs: int = Field(..., gt=0, description="Number of sub-carriers")
8
  num_symbols: int = Field(..., gt=0, description="Number of OFDM symbols")
9
 
10
 
11
  class PilotParams(BaseModel):
 
 
12
  num_scs: int = Field(..., gt=0, description="Number of pilots across sub-carriers")
13
  num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
14
 
@@ -17,7 +21,7 @@ class SystemConfig(BaseModel):
17
  ofdm: OFDMParams
18
  pilot: PilotParams
19
 
20
- @model_validator(mode='after')
21
  def validate_pilot_constraints(self) -> Self:
22
  """Ensure pilot parameters don't exceed OFDM parameters."""
23
  if self.pilot.num_scs > self.ofdm.num_scs:
@@ -33,25 +37,62 @@ class SystemConfig(BaseModel):
33
  )
34
  return self
35
 
36
- model_config = {"extra": "forbid"}
37
 
38
 
39
  class ModelConfig(BaseModel):
 
 
 
 
40
  patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
41
  num_layers: int = Field(..., gt=0, description="Number of transformer layers")
42
  model_dim: int = Field(..., gt=0, description="Model dimension")
43
  num_head: int = Field(..., gt=0, description="Number of attention heads")
44
- activation: str = Field(default="gelu", description="Activation function")
45
- dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
 
 
 
46
  max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
47
- pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
48
- adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
 
 
 
 
 
 
 
49
  channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
50
  default=None,
51
- description="Hidden sizes for channel adaptation layers"
52
  )
53
  device: str = Field(default="cpu", description="Device to use")
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @model_validator(mode='after')
56
  def validate_device(self) -> Self:
57
  """Validate that the specified device is available."""
@@ -67,7 +108,6 @@ class ModelConfig(BaseModel):
67
  self.device = 'cpu'
68
  return self
69
 
70
- # Validate CPU
71
  if device_str == 'cpu':
72
  return self
73
 
 
1
  from pydantic import BaseModel, Field, model_validator
2
+ from typing import Self, Tuple, List, Optional, Literal
3
  import torch
4
 
5
 
6
  class OFDMParams(BaseModel):
7
+ # ... means required (i.e. no default value)
8
+ # gt=0 means greater than 0
9
  num_scs: int = Field(..., gt=0, description="Number of sub-carriers")
10
  num_symbols: int = Field(..., gt=0, description="Number of OFDM symbols")
11
 
12
 
13
  class PilotParams(BaseModel):
14
+ # ... means required (i.e. no default value)
15
+ # gt=0 means greater than 0
16
  num_scs: int = Field(..., gt=0, description="Number of pilots across sub-carriers")
17
  num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
18
 
 
21
  ofdm: OFDMParams
22
  pilot: PilotParams
23
 
24
+ @model_validator(mode='after') # validate after all fields are initialized
25
  def validate_pilot_constraints(self) -> Self:
26
  """Ensure pilot parameters don't exceed OFDM parameters."""
27
  if self.pilot.num_scs > self.ofdm.num_scs:
 
37
  )
38
  return self
39
 
40
+ model_config = {"extra": "forbid"} # forbid extra fields
41
 
42
 
43
  class ModelConfig(BaseModel):
44
+ model_type: Literal["fortitran", "adafortitran"] = Field(
45
+ default="fortitran",
46
+ description="Type of model (fortitran or adafortitran)"
47
+ )
48
  patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
49
  num_layers: int = Field(..., gt=0, description="Number of transformer layers")
50
  model_dim: int = Field(..., gt=0, description="Model dimension")
51
  num_head: int = Field(..., gt=0, description="Number of attention heads")
52
+ activation: Literal["relu", "gelu"] = Field(
53
+ default="gelu",
54
+ description="Activation function used within the transformer's FFN"
55
+ )
56
+ dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate used within the transformer's FFN")
57
  max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
58
+ pos_encoding_type: Literal["learnable", "sinusoidal"] = Field(
59
+ default="learnable",
60
+ description="Positional encoding type"
61
+ )
62
+ adaptive_token_length: Optional[int] = Field(
63
+ default=None,
64
+ gt=0,
65
+ description="Adaptive token length (required for AdaFortiTran)"
66
+ )
67
  channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
68
  default=None,
69
+ description="Hidden sizes for channel adaptation layers (required for AdaFortiTran)"
70
  )
71
  device: str = Field(default="cpu", description="Device to use")
72
 
73
+ @model_validator(mode='after')
74
+ def validate_model_specific_requirements(self) -> Self:
75
+ """Validate model-specific configuration requirements."""
76
+ if self.model_type == "adafortitran":
77
+ if self.channel_adaptivity_hidden_sizes is None:
78
+ raise ValueError(
79
+ "channel_adaptivity_hidden_sizes is required for AdaFortiTran model"
80
+ )
81
+ if self.adaptive_token_length is None:
82
+ raise ValueError(
83
+ "adaptive_token_length is required for AdaFortiTran model"
84
+ )
85
+
86
+ if self.model_type == "fortitran":
87
+ if self.channel_adaptivity_hidden_sizes is not None:
88
+ # Note: channel_adaptivity_hidden_sizes will be ignored for FortiTran
89
+ pass
90
+ if self.adaptive_token_length is not None:
91
+ # Note: adaptive_token_length will be ignored for FortiTran
92
+ pass
93
+
94
+ return self
95
+
96
  @model_validator(mode='after')
97
  def validate_device(self) -> Self:
98
  """Validate that the specified device is available."""
 
108
  self.device = 'cpu'
109
  return self
110
 
 
111
  if device_str == 'cpu':
112
  return self
113
 
src/main.py CHANGED
@@ -13,7 +13,7 @@ from pathlib import Path
13
 
14
  from src.main.parser import parse_arguments
15
  from src.main.trainer import train
16
- from src.config.config_loader import load_config
17
 
18
 
19
  def setup_logging(log_level: str) -> None:
@@ -58,7 +58,10 @@ def main() -> None:
58
  logger.info("Configuration loaded successfully")
59
  logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
60
  logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
61
- logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
 
 
 
62
 
63
  # Start training
64
  logger.info("Initializing training...")
 
13
 
14
  from src.main.parser import parse_arguments
15
  from src.main.trainer import train
16
+ from src.config import load_config
17
 
18
 
19
  def setup_logging(log_level: str) -> None:
 
58
  logger.info("Configuration loaded successfully")
59
  logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
60
  logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
61
+ if model_config is not None:
62
+ logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
63
+ else:
64
+ logger.info("Using Linear model (no model config required)")
65
 
66
  # Start training
67
  logger.info("Initializing training...")