BerkIGuler committed on
Commit
687eaba
·
1 Parent(s): 9727e5e

fixes on src/models

Browse files
config/linear.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ model_type: 'linear'
2
+ device: 'cpu'
src/config/config_loader.py CHANGED
@@ -13,31 +13,30 @@ class ConfigLoader:
13
  def __init__(self):
14
  self.logger = logging.getLogger(__name__)
15
 
16
- def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Optional[Union[str, Path]] = None) -> Tuple[SystemConfig, Optional[ModelConfig]]:
17
  """
18
  Load and validate configuration files from YAML files.
19
 
20
  Args:
21
  system_config_path: Path to YAML configuration file for OFDM-related parameters
22
- model_config_path: Optional path to YAML configuration file for model-related parameters
23
 
24
  Returns:
25
- Tuple of (SystemConfig, Optional[ModelConfig]): Validated configuration objects
26
 
27
  Raises:
28
- FileNotFoundError: If the system config file doesn't exist
29
  ValueError: If configuration validation fails
30
  """
31
  system_config_path = Path(system_config_path)
32
-
33
- # certain models may not have a model config
34
- model_config = None
35
- if model_config_path is not None:
36
- model_config_path = Path(model_config_path)
37
 
38
  if not system_config_path.exists():
39
  raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
40
 
 
 
 
41
  try:
42
  with open(system_config_path, 'r') as f:
43
  system_raw_config = yaml.safe_load(f)
@@ -50,22 +49,22 @@ class ConfigLoader:
50
  except ValidationError as e:
51
  raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
52
 
53
- if model_config_path is not None and model_config_path.exists():
54
- try:
55
- with open(model_config_path, 'r') as f:
56
- model_raw_config = yaml.safe_load(f)
57
- except yaml.YAMLError as e:
58
- raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
59
- try:
60
- model_config = ModelConfig(**model_raw_config)
61
- self.logger.info(f"Successfully loaded model config from {model_config_path}")
62
- except ValidationError as e:
63
- raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
64
 
65
  return system_config, model_config
66
 
67
 
68
- def load_config(system_config_path: Union[str, Path], model_config_path: Optional[Union[str, Path]] = None) -> Tuple[SystemConfig, Optional[ModelConfig]]:
69
  """Convenience function to load and validate config."""
70
  config_loader = ConfigLoader()
71
  return config_loader.load_and_validate(system_config_path, model_config_path)
 
13
  def __init__(self):
14
  self.logger = logging.getLogger(__name__)
15
 
16
+ def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
17
  """
18
  Load and validate configuration files from YAML files.
19
 
20
  Args:
21
  system_config_path: Path to YAML configuration file for OFDM-related parameters
22
+ model_config_path: Path to YAML configuration file for model-related parameters
23
 
24
  Returns:
25
+ Tuple of (SystemConfig, ModelConfig): Validated configuration objects
26
 
27
  Raises:
28
+ FileNotFoundError: If either config file doesn't exist
29
  ValueError: If configuration validation fails
30
  """
31
  system_config_path = Path(system_config_path)
32
+ model_config_path = Path(model_config_path)
 
 
 
 
33
 
34
  if not system_config_path.exists():
35
  raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
36
 
37
+ if not model_config_path.exists():
38
+ raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
39
+
40
  try:
41
  with open(system_config_path, 'r') as f:
42
  system_raw_config = yaml.safe_load(f)
 
49
  except ValidationError as e:
50
  raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
51
 
52
+ try:
53
+ with open(model_config_path, 'r') as f:
54
+ model_raw_config = yaml.safe_load(f)
55
+ except yaml.YAMLError as e:
56
+ raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
57
+
58
+ try:
59
+ model_config = ModelConfig(**model_raw_config)
60
+ self.logger.info(f"Successfully loaded model config from {model_config_path}")
61
+ except ValidationError as e:
62
+ raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
63
 
64
  return system_config, model_config
65
 
66
 
67
def load_config(
    system_config_path: Union[str, Path],
    model_config_path: Union[str, Path],
) -> Tuple[SystemConfig, ModelConfig]:
    """Load and validate the system and model configs in one call.

    Thin wrapper that builds a throwaway ``ConfigLoader`` and delegates to
    its ``load_and_validate`` method.

    Args:
        system_config_path: Path to the YAML file with the OFDM-related
            parameters.
        model_config_path: Path to the YAML file with the model-related
            parameters.

    Returns:
        The ``(SystemConfig, ModelConfig)`` pair produced by
        ``ConfigLoader.load_and_validate``.
    """
    return ConfigLoader().load_and_validate(system_config_path, model_config_path)
src/config/schemas.py CHANGED
@@ -1,5 +1,5 @@
1
  from pydantic import BaseModel, Field, model_validator
2
- from typing import Self, Tuple, List, Optional, Literal
3
  import torch
4
 
5
 
@@ -40,59 +40,11 @@ class SystemConfig(BaseModel):
40
  model_config = {"extra": "forbid"} # forbid extra fields
41
 
42
 
43
- class ModelConfig(BaseModel):
44
- model_type: Literal["fortitran", "adafortitran"] = Field(
45
- default="fortitran",
46
- description="Type of model (fortitran or adafortitran)"
47
- )
48
- patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
49
- num_layers: int = Field(..., gt=0, description="Number of transformer layers")
50
- model_dim: int = Field(..., gt=0, description="Model dimension")
51
- num_head: int = Field(..., gt=0, description="Number of attention heads")
52
- activation: Literal["relu", "gelu"] = Field(
53
- default="gelu",
54
- description="Activation function used within the transformer's FFN"
55
- )
56
- dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate used within the transformer's FFN")
57
- max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
58
- pos_encoding_type: Literal["learnable", "sinusoidal"] = Field(
59
- default="learnable",
60
- description="Positional encoding type"
61
- )
62
- adaptive_token_length: Optional[int] = Field(
63
- default=None,
64
- gt=0,
65
- description="Adaptive token length (required for AdaFortiTran)"
66
- )
67
- channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
68
- default=None,
69
- description="Hidden sizes for channel adaptation layers (required for AdaFortiTran)"
70
- )
71
  device: str = Field(default="cpu", description="Device to use")
72
 
73
- @model_validator(mode='after')
74
- def validate_model_specific_requirements(self) -> Self:
75
- """Validate model-specific configuration requirements."""
76
- if self.model_type == "adafortitran":
77
- if self.channel_adaptivity_hidden_sizes is None:
78
- raise ValueError(
79
- "channel_adaptivity_hidden_sizes is required for AdaFortiTran model"
80
- )
81
- if self.adaptive_token_length is None:
82
- raise ValueError(
83
- "adaptive_token_length is required for AdaFortiTran model"
84
- )
85
-
86
- if self.model_type == "fortitran":
87
- if self.channel_adaptivity_hidden_sizes is not None:
88
- # Note: channel_adaptivity_hidden_sizes will be ignored for FortiTran
89
- pass
90
- if self.adaptive_token_length is not None:
91
- # Note: adaptive_token_length will be ignored for FortiTran
92
- pass
93
-
94
- return self
95
-
96
  @model_validator(mode='after')
97
  def validate_device(self) -> Self:
98
  """Validate that the specified device is available."""
@@ -152,4 +104,59 @@ class ModelConfig(BaseModel):
152
  f"Available devices: {available_devices}"
153
  )
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  model_config = {"extra": "forbid"}
 
1
  from pydantic import BaseModel, Field, model_validator
2
+ from typing import Self, Tuple, List, Optional, Literal, Union
3
  import torch
4
 
5
 
 
40
  model_config = {"extra": "forbid"} # forbid extra fields
41
 
42
 
43
+ class BaseConfig(BaseModel):
44
+ """Base configuration class with device validation."""
45
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  device: str = Field(default="cpu", description="Device to use")
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  @model_validator(mode='after')
49
  def validate_device(self) -> Self:
50
  """Validate that the specified device is available."""
 
104
  f"Available devices: {available_devices}"
105
  )
106
 
107
+
108
class ModelConfig(BaseConfig):
    """Model configuration for the linear, FortiTran and AdaFortiTran estimators.

    The transformer hyper-parameters (``patch_size``, ``num_layers``,
    ``model_dim``, ``num_head``) are optional at the field level so that a
    minimal linear-model config (``model_type`` and ``device`` only) passes
    validation. They are still mandatory for the ``fortitran`` and
    ``adafortitran`` model types; that requirement is enforced in
    ``validate_model_specific_requirements``.
    """

    model_type: Literal["linear", "fortitran", "adafortitran"] = Field(
        default="fortitran",
        description="Type of model (linear, fortitran, or adafortitran)"
    )
    # Transformer-only parameters: default to None so a linear config may omit them.
    patch_size: Optional[Tuple[int, int]] = Field(
        default=None,
        description="Patch size as (subcarriers_per_patch, symbols_per_patch)"
    )
    num_layers: Optional[int] = Field(default=None, gt=0, description="Number of transformer layers")
    model_dim: Optional[int] = Field(default=None, gt=0, description="Model dimension")
    num_head: Optional[int] = Field(default=None, gt=0, description="Number of attention heads")
    activation: Literal["relu", "gelu"] = Field(
        default="gelu",
        description="Activation function used within the transformer's FFN"
    )
    dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate used within the transformer's FFN")
    max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
    pos_encoding_type: Literal["learnable", "sinusoidal"] = Field(
        default="learnable",
        description="Positional encoding type"
    )
    adaptive_token_length: Optional[int] = Field(
        default=None,
        gt=0,
        description="Adaptive token length (required for AdaFortiTran)"
    )
    channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
        default=None,
        description="Hidden sizes for channel adaptation layers (required for AdaFortiTran)"
    )

    @model_validator(mode='after')
    def validate_model_specific_requirements(self) -> Self:
        """Validate model-specific configuration requirements.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If a transformer-based model type is missing one of
                its required hyper-parameters, or if ``adafortitran`` is
                missing its channel-adaptation settings.
        """
        if self.model_type == "linear":
            # Linear model only needs device, no additional validation required
            return self

        # Both FortiTran variants need the core transformer hyper-parameters.
        required_fields = ("patch_size", "num_layers", "model_dim", "num_head")
        missing = [name for name in required_fields if getattr(self, name) is None]
        if missing:
            raise ValueError(
                f"{', '.join(missing)} required for '{self.model_type}' model"
            )

        if self.model_type == "adafortitran":
            if self.channel_adaptivity_hidden_sizes is None:
                raise ValueError(
                    "channel_adaptivity_hidden_sizes is required for AdaFortiTran model"
                )
            if self.adaptive_token_length is None:
                raise ValueError(
                    "adaptive_token_length is required for AdaFortiTran model"
                )
        # For "fortitran", channel_adaptivity_hidden_sizes and
        # adaptive_token_length may be set but are simply ignored.

        return self

    model_config = {"extra": "forbid"}
src/main.py CHANGED
@@ -44,6 +44,7 @@ from pathlib import Path
44
  from src.main.parser import parse_arguments
45
  from src.main.trainer import train
46
  from src.config import load_config
 
47
 
48
 
49
  def setup_logging(log_level: str) -> None:
@@ -88,10 +89,10 @@ def main() -> None:
88
  logger.info("Configuration loaded successfully")
89
  logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
90
  logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
91
- if model_config is not None:
92
- logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
93
  else:
94
- logger.info("Using Linear model (no model config required)")
95
 
96
  # Start training
97
  logger.info("Initializing training...")
 
44
  from src.main.parser import parse_arguments
45
  from src.main.trainer import train
46
  from src.config import load_config
47
+ from src.config.schemas import ModelConfig
48
 
49
 
50
  def setup_logging(log_level: str) -> None:
 
89
  logger.info("Configuration loaded successfully")
90
  logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
91
  logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
92
+ if model_config.model_type == "linear":
93
+ logger.info(f"Linear model with device: {model_config.device}")
94
  else:
95
+ logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
96
 
97
  # Start training
98
  logger.info("Initializing training...")
src/main/trainer.py CHANGED
@@ -69,7 +69,7 @@ class ModelTrainer:
69
 
70
  EXP_LR_GAMMA = 0.995
71
 
72
- def __init__(self, system_config: SystemConfig, model_config: ModelConfig | None, args: TrainingArguments):
73
  """
74
  Initialize the ModelTrainer.
75
 
@@ -121,14 +121,10 @@ class ModelTrainer:
121
  Initialized model instance of the specified type
122
  """
123
  if self.args.model_name == "linear":
124
- model = LinearEstimator(self.system_config, device=str(self.device))
125
  elif self.args.model_name == "adafortitran":
126
- if self.model_config is None:
127
- raise ValueError("model_config must be provided for AdaFortiTranEstimator.")
128
  model = AdaFortiTranEstimator(self.system_config, self.model_config)
129
  elif self.args.model_name == "fortitran":
130
- if self.model_config is None:
131
- raise ValueError("model_config must be provided for FortiTranEstimator.")
132
  model = FortiTranEstimator(self.system_config, self.model_config)
133
  else:
134
  raise ValueError(f"Unknown model name: {self.args.model_name}")
@@ -406,7 +402,7 @@ class ModelTrainer:
406
  self.writer.close()
407
 
408
 
409
- def train(system_config: SystemConfig, model_config: ModelConfig | None, args: TrainingArguments) -> None:
410
  """
411
  Train an OFDM channel estimation model.
412
 
 
69
 
70
  EXP_LR_GAMMA = 0.995
71
 
72
+ def __init__(self, system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments):
73
  """
74
  Initialize the ModelTrainer.
75
 
 
121
  Initialized model instance of the specified type
122
  """
123
  if self.args.model_name == "linear":
124
+ model = LinearEstimator(self.system_config, self.model_config)
125
  elif self.args.model_name == "adafortitran":
 
 
126
  model = AdaFortiTranEstimator(self.system_config, self.model_config)
127
  elif self.args.model_name == "fortitran":
 
 
128
  model = FortiTranEstimator(self.system_config, self.model_config)
129
  else:
130
  raise ValueError(f"Unknown model name: {self.args.model_name}")
 
402
  self.writer.close()
403
 
404
 
405
+ def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:
406
  """
407
  Train an OFDM channel estimation model.
408
 
src/models/blocks/enhancers.py CHANGED
@@ -23,9 +23,9 @@ class ConvEnhancer(nn.Module):
23
  """Forward pass through the convolutional enhancement network.
24
 
25
  Args:
26
- x (torch.Tensor): Input tensor of shape (batch_size, 1, height, width)
27
 
28
  Returns:
29
- torch.Tensor: Enhanced tensor of shape (batch_size, 1, height, width)
30
  """
31
  return self.conv_block(x)
 
23
  """Forward pass through the convolutional enhancement network.
24
 
25
  Args:
26
+ x (torch.Tensor): Input tensor of shape (batch_size, 1, num_subcarriers, num_symbols)
27
 
28
  Returns:
29
+ torch.Tensor: Enhanced tensor of shape (batch_size, 1, num_subcarriers, num_symbols)
30
  """
31
  return self.conv_block(x)
src/models/blocks/patch_processors.py CHANGED
@@ -15,7 +15,7 @@ class PatchEmbedding(nn.Module):
15
  """Initialize the PatchEmbedding layer.
16
 
17
  Args:
18
- patch_size: Size of patches to extract (height, width)
19
  """
20
  super().__init__()
21
  self.patch_size = patch_size
@@ -25,11 +25,11 @@ class PatchEmbedding(nn.Module):
25
  """Transform input tensor into patch embeddings.
26
 
27
  Args:
28
- x: Input tensor of shape (batch_size, height, width)
29
 
30
  Returns:
31
  Tensor of shape (batch_size, num_patches, patch_size[0]*patch_size[1])
32
- where num_patches = (height // patch_size[0]) * (width // patch_size[1])
33
  """
34
  x = self.unfold(torch.unsqueeze(x, dim=1))
35
  return torch.permute(x, dims=(0, 2, 1))
@@ -46,8 +46,8 @@ class InversePatchEmbedding(nn.Module):
46
  """Initialize the InversePatchEmbedding layer.
47
 
48
  Args:
49
- output_size: Size of output matrix (height, width)
50
- patch_size: Size of input patches (height, width)
51
  """
52
  super().__init__()
53
  self.fold = nn.Fold(
@@ -64,7 +64,7 @@ class InversePatchEmbedding(nn.Module):
64
  where num_patches = (output_size[0] // patch_size[0]) * (output_size[1] // patch_size[1])
65
 
66
  Returns:
67
- Tensor of shape (batch_size, output_size[0], output_size[1])
68
  """
69
  x = torch.permute(x, dims=(0, 2, 1))
70
  x = self.fold(x)
 
15
  """Initialize the PatchEmbedding layer.
16
 
17
  Args:
18
+ patch_size: Size of patches to extract (subcarriers_per_patch, symbols_per_patch)
19
  """
20
  super().__init__()
21
  self.patch_size = patch_size
 
25
  """Transform input tensor into patch embeddings.
26
 
27
  Args:
28
+ x: Input tensor of shape (batch_size, num_subcarriers, num_symbols)
29
 
30
  Returns:
31
  Tensor of shape (batch_size, num_patches, patch_size[0]*patch_size[1])
32
+ where num_patches = (num_subcarriers // patch_size[0]) * (num_symbols // patch_size[1])
33
  """
34
  x = self.unfold(torch.unsqueeze(x, dim=1))
35
  return torch.permute(x, dims=(0, 2, 1))
 
46
  """Initialize the InversePatchEmbedding layer.
47
 
48
  Args:
49
+ output_size: Size of output matrix (num_subcarriers, num_symbols)
50
+ patch_size: Size of input patches (subcarriers_per_patch, symbols_per_patch)
51
  """
52
  super().__init__()
53
  self.fold = nn.Fold(
 
64
  where num_patches = (output_size[0] // patch_size[0]) * (output_size[1] // patch_size[1])
65
 
66
  Returns:
67
+ Tensor of shape (batch_size, num_subcarriers, num_symbols)
68
  """
69
  x = torch.permute(x, dims=(0, 2, 1))
70
  x = self.fold(x)
src/models/fortitran.py CHANGED
@@ -4,8 +4,7 @@ import logging
4
  from typing import Tuple, List, Optional
5
 
6
  from src.config.schemas import SystemConfig, ModelConfig
7
- from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, \
8
- ChannelAdapter
9
 
10
 
11
  class BaseFortiTranEstimator(nn.Module):
@@ -13,11 +12,11 @@ class BaseFortiTranEstimator(nn.Module):
13
  Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
14
 
15
  This model performs channel estimation by:
16
- 1. Upsampling pilot symbols to full OFDM grid size
17
- 2. Applying convolutional enhancement for spatial features
18
  3. Converting to patch embeddings for transformer processing
19
  4. Using transformer encoder to capture long-range dependencies
20
- 5. Reconstructing spatial representation and applying residual connections
21
  6. Final convolutional refinement for high-quality channel estimates
22
  """
23
 
@@ -29,7 +28,7 @@ class BaseFortiTranEstimator(nn.Module):
29
  Args:
30
  system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
31
  model_config: Model architecture configuration (patch size, layers, etc.)
32
- use_channel_adaptation: Whether to enable channel adaptation features
33
  """
34
  super().__init__()
35
 
@@ -73,11 +72,13 @@ class BaseFortiTranEstimator(nn.Module):
73
  self.model_config.patch_size[0] * self.model_config.patch_size[1]
74
  )
75
 
76
- # Adaptive patch length (only used if channel adaptation is enabled)
77
  if self.use_channel_adaptation:
78
- self.adaptive_patch_length = self.patch_length + self.model_config.adaptive_token_length
 
 
79
  else:
80
- self.adaptive_patch_length = self.patch_length
81
 
82
  def _build_architecture(self) -> None:
83
  """Construct the model architecture components."""
@@ -92,14 +93,19 @@ class BaseFortiTranEstimator(nn.Module):
92
 
93
  # 4. Channel adapter (conditional on use_channel_adaptation)
94
  if self.use_channel_adaptation:
95
- self.channel_adapter = ChannelAdapter(self.model_config.channel_adaptivity_hidden_sizes)
 
 
 
 
 
 
96
 
97
  # 5. Transformer encoder for sequence modeling
98
- transformer_input_dim = self.adaptive_patch_length if self.use_channel_adaptation else self.patch_length
99
  transformer_output_dim = self.patch_length # Always output standard patch length
100
 
101
  self.transformer_encoder = TransformerEncoderForChannels(
102
- input_dim=transformer_input_dim,
103
  output_dim=transformer_output_dim,
104
  model_dim=self.model_config.model_dim,
105
  num_head=self.model_config.num_head,
@@ -189,7 +195,7 @@ class BaseFortiTranEstimator(nn.Module):
189
  """
190
  batch_size = x.shape[0]
191
 
192
- # Flatten spatial dimensions for linear upsampling
193
  if x.dim() > 2:
194
  x = x.view(batch_size, -1)
195
 
@@ -215,7 +221,7 @@ class BaseFortiTranEstimator(nn.Module):
215
  # Stage 5: Transformer processing for long-range dependencies
216
  transformer_output = self.transformer_encoder(transformer_input)
217
 
218
- # Stage 6: Reconstruct spatial representation
219
  reconstructed = self.patch_reconstructor(transformer_output)
220
 
221
  # Stage 7: Apply residual connection
@@ -235,7 +241,7 @@ class BaseFortiTranEstimator(nn.Module):
235
  'pilot_size': self.pilot_size,
236
  'patch_size': self.model_config.patch_size,
237
  'patch_length': self.patch_length,
238
- 'adaptive_patch_length': self.adaptive_patch_length,
239
  'model_dim': self.model_config.model_dim,
240
  'num_layers': self.model_config.num_layers,
241
  'device': str(self.device),
 
4
  from typing import Tuple, List, Optional
5
 
6
  from src.config.schemas import SystemConfig, ModelConfig
7
+ from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, ChannelAdapter
 
8
 
9
 
10
  class BaseFortiTranEstimator(nn.Module):
 
12
  Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
13
 
14
  This model performs channel estimation by:
15
+ 1. Upsampling pilot symbols to full OFDM grid size (with linear layer)
16
+ 2. Applying convolutional enhancement for subcarrier-symbol features
17
  3. Converting to patch embeddings for transformer processing
18
  4. Using transformer encoder to capture long-range dependencies
19
+ 5. Reconstructing subcarrier-symbol representation and applying residual connections
20
  6. Final convolutional refinement for high-quality channel estimates
21
  """
22
 
 
28
  Args:
29
  system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
30
  model_config: Model architecture configuration (patch size, layers, etc.)
31
+ use_channel_adaptation: Whether to enable channel adaptation features (disabled for FortiTran)
32
  """
33
  super().__init__()
34
 
 
72
  self.model_config.patch_size[0] * self.model_config.patch_size[1]
73
  )
74
 
75
+ # Transformer input dimension (includes channel tokens if adaptation is enabled)
76
  if self.use_channel_adaptation:
77
+ if self.model_config.adaptive_token_length is None:
78
+ raise ValueError("adaptive_token_length must be set when channel adaptation is enabled")
79
+ self.transformer_input_dim = self.patch_length + self.model_config.adaptive_token_length
80
  else:
81
+ self.transformer_input_dim = self.patch_length
82
 
83
  def _build_architecture(self) -> None:
84
  """Construct the model architecture components."""
 
93
 
94
  # 4. Channel adapter (conditional on use_channel_adaptation)
95
  if self.use_channel_adaptation:
96
+ if self.model_config.channel_adaptivity_hidden_sizes is None:
97
+ raise ValueError("channel_adaptivity_hidden_sizes must be set when channel adaptation is enabled")
98
+ # Convert list to tuple as expected by ChannelAdapter (exactly 3 values)
99
+ hidden_sizes = tuple(self.model_config.channel_adaptivity_hidden_sizes)
100
+ if len(hidden_sizes) != 3:
101
+ raise ValueError("channel_adaptivity_hidden_sizes must have exactly 3 values")
102
+ self.channel_adapter = ChannelAdapter(hidden_sizes)
103
 
104
  # 5. Transformer encoder for sequence modeling
 
105
  transformer_output_dim = self.patch_length # Always output standard patch length
106
 
107
  self.transformer_encoder = TransformerEncoderForChannels(
108
+ input_dim=self.transformer_input_dim,
109
  output_dim=transformer_output_dim,
110
  model_dim=self.model_config.model_dim,
111
  num_head=self.model_config.num_head,
 
195
  """
196
  batch_size = x.shape[0]
197
 
198
+ # Flatten subcarrier and symbol dimensions for linear upsampling
199
  if x.dim() > 2:
200
  x = x.view(batch_size, -1)
201
 
 
221
  # Stage 5: Transformer processing for long-range dependencies
222
  transformer_output = self.transformer_encoder(transformer_input)
223
 
224
+ # Stage 6: Reconstruct subcarrier-symbol representation
225
  reconstructed = self.patch_reconstructor(transformer_output)
226
 
227
  # Stage 7: Apply residual connection
 
241
  'pilot_size': self.pilot_size,
242
  'patch_size': self.model_config.patch_size,
243
  'patch_length': self.patch_length,
244
+ 'transformer_input_dim': self.transformer_input_dim,
245
  'model_dim': self.model_config.model_dim,
246
  'num_layers': self.model_config.num_layers,
247
  'device': str(self.device),
src/models/linear.py CHANGED
@@ -10,43 +10,47 @@ import logging
10
  import torch
11
  import torch.nn as nn
12
 
13
- from src.config.schemas import SystemConfig
14
 
15
 
16
  class LinearEstimator(nn.Module):
17
  """Learned MMSE estimator.
18
 
 
 
19
  Attributes:
20
  device (torch.device): Target device for computation
21
- config (SystemConfig): Validated configuration object
22
- ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (height, width)
23
- height (int): number of sub-carriers
24
- width (int): number of OFDM symbols
25
- pilot_size (Tuple[int, int]): Dimensions of pilot signal as (height, width)
26
- height (int): number of pilots across sub-carriers
27
- width (int): number of pilots across OFDM symbols
 
28
  """
29
 
30
- def __init__(self, config: SystemConfig, device: str = "cpu") -> None:
31
  """Initialize the MMSE estimator.
32
 
33
  Args:
34
- config: Validated SystemConfig object containing OFDM system parameters
35
- device: Device to use for computation (cpu, cuda, etc.)
36
  """
37
  super().__init__()
38
 
39
- self.config = config
40
- self.device = torch.device(device)
 
41
  self.logger = logging.getLogger(__name__)
42
 
43
  # Extract dimensions from validated config
44
- self.ofdm_size = (config.ofdm.num_scs, config.ofdm.num_symbols)
45
- self.pilot_size = (config.pilot.num_scs, config.pilot.num_symbols)
46
 
47
  # Calculate feature dimensions
48
- in_feature_dim = config.pilot.num_scs * config.pilot.num_symbols
49
- out_feature_dim = config.ofdm.num_scs * config.ofdm.num_symbols
50
 
51
  self.logger.info(f"Initializing LinearEstimator:")
52
  self.logger.info(f" OFDM size: {self.ofdm_size}")
@@ -70,7 +74,7 @@ class LinearEstimator(nn.Module):
70
  Estimated OFDM signal tensor with shape
71
  (batch_size, ofdm_size[0], ofdm_size[1])
72
  """
73
- # pytorch does nothin if input is already on correct device
74
  x = x.to(self.device)
75
  self.logger.debug(f"Input shape: {x.size()}")
76
 
@@ -95,14 +99,6 @@ class LinearEstimator(nn.Module):
95
 
96
  return x
97
 
98
- def get_config(self) -> SystemConfig:
99
- """Get the configuration used by this estimator.
100
-
101
- Returns:
102
- SystemConfig: The configuration object
103
- """
104
- return self.config
105
-
106
  def __repr__(self) -> str:
107
  """String representation of the estimator."""
108
  return (
 
10
  import torch
11
  import torch.nn as nn
12
 
13
+ from src.config.schemas import SystemConfig, ModelConfig
14
 
15
 
16
  class LinearEstimator(nn.Module):
17
  """Learned MMSE estimator.
18
 
19
+ Learns a matrix W such that h_hat = W*h_pilot, where h_hat is the estimated channel; W is fitted by stochastic gradient descent on |h_hat - h_ideal|^2
20
+
21
  Attributes:
22
  device (torch.device): Target device for computation
23
+ system_config (SystemConfig): Validated configuration object for OFDM system parameters
24
+ model_config (ModelConfig): Validated configuration object for model parameters
25
+ ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (num_subcarriers, num_symbols)
26
+ num_subcarriers (int): number of sub-carriers
27
+ num_symbols (int): number of OFDM symbols
28
+ pilot_size (Tuple[int, int]): Dimensions of pilot signal as (num_subcarriers, num_symbols)
29
+ num_subcarriers (int): number of pilots across sub-carriers
30
+ num_symbols (int): number of pilots across OFDM symbols
31
  """
32
 
33
+ def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
34
  """Initialize the MMSE estimator.
35
 
36
  Args:
37
+ system_config: Validated SystemConfig object containing OFDM system parameters
38
+ model_config: Validated ModelConfig object containing model parameters
39
  """
40
  super().__init__()
41
 
42
+ self.system_config = system_config
43
+ self.model_config = model_config
44
+ self.device = torch.device(model_config.device)
45
  self.logger = logging.getLogger(__name__)
46
 
47
  # Extract dimensions from validated config
48
+ self.ofdm_size = (system_config.ofdm.num_scs, system_config.ofdm.num_symbols)
49
+ self.pilot_size = (system_config.pilot.num_scs, system_config.pilot.num_symbols)
50
 
51
  # Calculate feature dimensions
52
+ in_feature_dim = system_config.pilot.num_scs * system_config.pilot.num_symbols
53
+ out_feature_dim = system_config.ofdm.num_scs * system_config.ofdm.num_symbols
54
 
55
  self.logger.info(f"Initializing LinearEstimator:")
56
  self.logger.info(f" OFDM size: {self.ofdm_size}")
 
74
  Estimated OFDM signal tensor with shape
75
  (batch_size, ofdm_size[0], ofdm_size[1])
76
  """
77
+ # pytorch does nothing if input is already on correct device
78
  x = x.to(self.device)
79
  self.logger.debug(f"Input shape: {x.size()}")
80
 
 
99
 
100
  return x
101
 
 
 
 
 
 
 
 
 
102
  def __repr__(self) -> str:
103
  """String representation of the estimator."""
104
  return (